chendl committed on
Commit 1810078
1 parent: 9acf8b9

update cap

app.py CHANGED
@@ -2,18 +2,16 @@ import os
2
  import sys
3
  from pathlib import Path
4
  # os.system("cd transformers && pip install .")
5
- os.system("cd multimodal && pip install .")
6
- os.system("cd multimodal/YOLOX && pip install .")
7
  import numpy as np
8
  import torch
9
  from PIL import Image
10
  import tempfile
11
 
12
-
13
  import string
14
  import cv2
15
 
16
-
17
  import gradio as gr
18
  import torch
19
  from PIL import Image
@@ -52,34 +50,34 @@ flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transfor
52
  enhance_data=False,
53
  )
54
 
55
-
56
- # checkpoint_path = "/home/aimos/huggingface/space/demo.pt"
57
  checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
58
  checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
59
  model_state_dict = {}
60
  for key in checkpoint.keys():
61
  model_state_dict[key.replace("module.", "")] = checkpoint[key]
62
- if "vision_encoder.logit_scale"in model_state_dict:
63
  # previous checkpoint has some unnecessary weights
64
  del model_state_dict["vision_encoder.logit_scale"]
65
  del model_state_dict["vision_encoder.visual.proj"]
66
  del model_state_dict["vision_encoder.visual.ln_post.weight"]
67
  del model_state_dict["vision_encoder.visual.ln_post.bias"]
68
  flamingo.load_state_dict(model_state_dict, strict=True)
69
- chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size )
 
70
 
71
  def get_outputs(
72
- model,
73
- batch_images,
74
- attention_mask,
75
- max_generation_length,
76
- min_generation_length,
77
- num_beams,
78
- length_penalty,
79
- input_ids,
80
- image_start_index_list=None,
81
- image_nums=None,
82
- bad_words_ids=None,
83
  ):
84
  # and torch.cuda.amp.autocast(dtype=torch.float16)
85
  with torch.inference_mode():
@@ -109,15 +107,13 @@ def get_outputs(
109
  return outputs
110
 
111
 
112
-
113
-
114
  def generate(
115
- idx,
116
- image,
117
- text,
118
- vis_embed_size=256,
119
- rank=0,
120
- world_size=1,
121
  ):
122
  if image is None:
123
  raise gr.Error("Please upload an image.")
@@ -138,7 +134,8 @@ def generate(
138
  image = image.resize((224, 224))
139
  batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
140
  if idx == 1:
141
- prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
 
142
  bad_words_ids = None
143
  max_generation_length = 5
144
  else:
@@ -174,14 +171,14 @@ def generate(
174
  boxes = outputs["boxes"]
175
  scores = outputs["scores"]
176
  if len(scores) > 0:
177
- box = boxes[scores.argmax()]/224
178
  print(f"{box}")
179
 
180
  if idx == 1:
181
  open_cv_image = np.array(image_ori)
182
  # Convert RGB to BGR
183
  open_cv_image = open_cv_image[:, :, ::-1].copy()
184
- box = box*[width,height,width,height]
185
  # for box in boxes:
186
  open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
187
  out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
@@ -199,6 +196,7 @@ description = """<h3>This is the demo of Compositional-VLM. Upload your images a
199
  article = """<div style='display:flex; gap: 0.25rem; '><a href='https://compositionalvlm.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
200
  """
201
 
 
202
  # TODO show examples below
203
 
204
  # ========================================
@@ -217,16 +215,17 @@ def gradio_reset(chat_state, img_list):
217
 
218
  def build_image(image):
219
  if image is None:
220
- return None
221
  # res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8)
222
  from torchvision.transforms import ToPILImage
223
  # res = ToPILImage()(res)
224
  _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR)
225
  image.save(path)
226
 
227
- return path
 
228
 
229
- def upload_img(gr_img, text_input, chat_state,chatbot):
230
  if gr_img is None:
231
  return None, None, gr.update(interactive=True), chat_state, None
232
  chat_state = []
@@ -235,42 +234,42 @@ def upload_img(gr_img, text_input, chat_state,chatbot):
235
  chatbot = chatbot + [[(path,), None]]
236
  llm_message = chat.upload_img(gr_img, chat_state, img_list)
237
  return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
238
- value="Start Chatting", interactive=False), chat_state, img_list,chatbot
239
 
240
 
241
- def gradio_ask(user_message, chatbot, chat_state,radio):
242
  if len(user_message) == 0:
243
  return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
244
 
245
-
246
- chat.ask(user_message, chat_state,radio)
247
  chatbot = chatbot + [[user_message, None]]
248
  return chatbot, chat_state
249
 
250
 
251
- def gradio_answer(chatbot, chat_state, img_list, radio, text,num_beams, temperature):
252
  image = None
253
- llm_message,image = \
254
- chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
255
- max_length=2000,radio = radio,text_input = text)
256
-
257
  chatbot[-1][1] = llm_message
258
- if chat_state[-1]["from"]=="gpt":
259
  chat_state[-1]["value"] = llm_message
260
- if image==None:
261
  return "", chatbot, chat_state, img_list
262
  else:
263
  path = build_image(image)
264
- chatbot = chatbot + [[None,(path,)]]
265
  return "", chatbot, chat_state, img_list
266
 
 
267
  task_template = {
268
- "Cap": "Summarize the content of the photo <image>.",
269
- "VQA": "For this image <image>, I want a simple and direct answer to my question: <question>",
270
- "REC": "Can you point out <expr> in the image <image> and provide the coordinates of its location?",
271
- "GC": "Can you give me a description of the region <boxes> in image <image>?",
272
- "Advanced": "<question>",
273
- }
274
 
275
  with gr.Blocks() as demo:
276
  gr.Markdown(title)
@@ -310,24 +309,25 @@ with gr.Blocks() as demo:
310
  img_list = gr.State()
311
  chatbot = gr.Chatbot(label='Compositional-VLM')
312
 
313
-
314
  # template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False,
315
  # value='Provide a comprehensive description of the image <image> and specify the positions of any mentioned objects in square brackets.')
316
  # text_input = gr.Textbox(label='<question>', show_label=True, placeholder="Please upload your image first, then input...", lines=3,
317
  # value=None, visible=False, interactive=False)
318
 
319
- text_input = gr.Textbox(label='User', placeholder='Please upload your image first, then input...', interactive=False)
 
320
 
321
- upload_button.click(upload_img, [image, text_input, chat_state,chatbot],
322
- [image, text_input, upload_button, chat_state, img_list,chatbot])
323
 
324
- text_input.submit(gradio_ask, [text_input, chatbot, chat_state,radio], [chatbot, chat_state]).then(
325
- gradio_answer, [chatbot, chat_state, img_list, radio, text_input,num_beams, temperature], [text_input,chatbot, chat_state, img_list]
 
326
  )
327
  clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
328
  queue=False)
329
 
330
- demo.launch(enable_queue=True,share=True)
331
  #
332
  # with gr.Blocks() as demo:
333
  # gr.Markdown(
 
2
  import sys
3
  from pathlib import Path
4
  # os.system("cd transformers && pip install .")
5
+ os.system("cd multimodal && pip install -e .")
6
+
7
  import numpy as np
8
  import torch
9
  from PIL import Image
10
  import tempfile
11
 
 
12
  import string
13
  import cv2
14
 
 
15
  import gradio as gr
16
  import torch
17
  from PIL import Image
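The install step above relies on os.system, which ignores a failing pip run, so the Space would keep going with whatever version of the package happens to be installed already. A more defensive variant using only the standard library (a sketch, not part of this commit):

import subprocess
import sys

def install_editable(package_dir: str) -> None:
    # Install the package in editable mode with the current interpreter and raise on failure.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", package_dir])

# install_editable("multimodal")  # equivalent to the os.system call above, but it fails loudly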
 
50
  enhance_data=False,
51
  )
52
 
53
+ model_name = "pythiaS"
 
54
  checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
55
  checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
56
  model_state_dict = {}
57
  for key in checkpoint.keys():
58
  model_state_dict[key.replace("module.", "")] = checkpoint[key]
59
+ if "vision_encoder.logit_scale" in model_state_dict:
60
  # previous checkpoint has some unnecessary weights
61
  del model_state_dict["vision_encoder.logit_scale"]
62
  del model_state_dict["vision_encoder.visual.proj"]
63
  del model_state_dict["vision_encoder.visual.ln_post.weight"]
64
  del model_state_dict["vision_encoder.visual.ln_post.bias"]
65
  flamingo.load_state_dict(model_state_dict, strict=True)
66
+ chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size)
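The checkpoint handling above follows the usual recipe for weights saved from a DataParallel/DDP-wrapped model: strip the "module." prefix from every key and drop a few stale vision-encoder tensors so that load_state_dict(..., strict=True) succeeds. The same idea as a reusable helper, purely as a sketch with a hypothetical name:

def clean_state_dict(checkpoint: dict, drop_keys=()) -> dict:
    # Remove the wrapper prefix and any known-stale keys before strict loading.
    state = {key.replace("module.", "", 1): value for key, value in checkpoint.items()}
    for key in drop_keys:
        state.pop(key, None)  # tolerate keys that are already absent
    return state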
67
+
68
 
69
  def get_outputs(
70
+ model,
71
+ batch_images,
72
+ attention_mask,
73
+ max_generation_length,
74
+ min_generation_length,
75
+ num_beams,
76
+ length_penalty,
77
+ input_ids,
78
+ image_start_index_list=None,
79
+ image_nums=None,
80
+ bad_words_ids=None,
81
  ):
82
  # and torch.cuda.amp.autocast(dtype=torch.float16)
83
  with torch.inference_mode():
 
107
  return outputs
108
 
109
 
 
 
110
  def generate(
111
+ idx,
112
+ image,
113
+ text,
114
+ vis_embed_size=256,
115
+ rank=0,
116
+ world_size=1,
117
  ):
118
  if image is None:
119
  raise gr.Error("Please upload an image.")
 
134
  image = image.resize((224, 224))
135
  batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
136
  if idx == 1:
137
+ prompt = [
138
+ f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
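For reference, with a hypothetical vis_embed_size of 4, bos_token "<BOS>", pad_token "<PAD>" and the query "a red car.", the grounding prompt built above expands to the string below; the deployed Space uses vis_embed_size = 256, so the pad run is simply much longer.

# "<BOS><|#image#|><PAD><PAD><PAD><PAD><|#endofimage#|><|#object#|> a red car<|#endofobject#|><|#visual#|>"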
139
  bad_words_ids = None
140
  max_generation_length = 5
141
  else:
 
171
  boxes = outputs["boxes"]
172
  scores = outputs["scores"]
173
  if len(scores) > 0:
174
+ box = boxes[scores.argmax()] / 224
175
  print(f"{box}")
176
 
177
  if idx == 1:
178
  open_cv_image = np.array(image_ori)
179
  # Convert RGB to BGR
180
  open_cv_image = open_cv_image[:, :, ::-1].copy()
181
+ box = box * [width, height, width, height]
182
  # for box in boxes:
183
  open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
184
  out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
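The two fixes above spell out the usual denormalization step: the best-scoring box comes back in the 224x224 space the model saw, is divided by 224 into relative coordinates, and is then scaled by the original image size before drawing. The same pipeline as a small self-contained sketch (hypothetical helper, assuming an OpenCV BGR image):

import numpy as np
import cv2

def draw_best_box(image_bgr: np.ndarray, box_224: np.ndarray) -> np.ndarray:
    # box_224 is (x1, y1, x2, y2) in the model's 224x224 input space.
    height, width = image_bgr.shape[:2]
    rel = box_224 / 224.0                                    # relative [0, 1] coordinates
    x1, y1, x2, y2 = (rel * [width, height, width, height]).astype(int).tolist()
    return cv2.rectangle(image_bgr, (x1, y1), (x2, y2), (255, 0, 0), 2)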
 
196
  article = """<div style='display:flex; gap: 0.25rem; '><a href='https://compositionalvlm.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
197
  """
198
 
199
+
200
  # TODO show examples below
201
 
202
  # ========================================
 
215
 
216
  def build_image(image):
217
  if image is None:
218
+ return None
219
  # res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8)
220
  from torchvision.transforms import ToPILImage
221
  # res = ToPILImage()(res)
222
  _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR)
223
  image.save(path)
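One detail of the temp-file handling here (unchanged by this commit): tempfile.mkstemp returns an open OS-level file descriptor together with the path, and discarding it as _ leaks that descriptor on every call. A hedged variant that closes it, sketched as a standalone helper:

import os
import tempfile
from PIL import Image

def save_temp_jpg(image: Image.Image, dir_path: str) -> str:
    fd, path = tempfile.mkstemp(suffix=".jpg", dir=dir_path)
    os.close(fd)  # mkstemp leaves the descriptor open; close it, PIL reopens the file by path
    image.save(path)
    return path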
224
 
225
+ return path
226
+
227
 
228
+ def upload_img(gr_img, text_input, chat_state, chatbot):
229
  if gr_img is None:
230
  return None, None, gr.update(interactive=True), chat_state, None
231
  chat_state = []
 
234
  chatbot = chatbot + [[(path,), None]]
235
  llm_message = chat.upload_img(gr_img, chat_state, img_list)
236
  return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
237
+ value="Start Chatting", interactive=False), chat_state, img_list, chatbot
238
 
239
 
240
+ def gradio_ask(user_message, chatbot, chat_state, radio):
241
  if len(user_message) == 0:
242
  return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
243
 
244
+ chat.ask(user_message, chat_state, radio, model_name)
 
245
  chatbot = chatbot + [[user_message, None]]
246
  return chatbot, chat_state
247
 
248
 
249
+ def gradio_answer(chatbot, chat_state, img_list, radio, text, num_beams, temperature):
250
  image = None
251
+ llm_message, image = \
252
+ chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
253
+ max_length=2000, radio=radio, text_input=text, model_name=model_name)
254
+
255
  chatbot[-1][1] = llm_message
256
+ if chat_state[-1]["from"] == "gpt":
257
  chat_state[-1]["value"] = llm_message
258
+ if image == None:
259
  return "", chatbot, chat_state, img_list
260
  else:
261
  path = build_image(image)
262
+ chatbot = chatbot + [[None, (path,)]]
263
  return "", chatbot, chat_state, img_list
264
 
265
+
266
  task_template = {
267
+ "Cap": "Summarize the content of the photo <image>.",
268
+ "VQA": "For this image <image>, I want a simple and direct answer to my question: <question>",
269
+ "REC": "Can you point out <expr> in the image <image> and provide the coordinates of its location?",
270
+ "GC": "Can you give me a description of the region <boxes> in image <image>?",
271
+ "Advanced": "<question>",
272
+ }
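The templates above carry <image>, <question>, <expr> and <boxes> placeholders. The substitution itself happens elsewhere in the repo, but a minimal illustration of how such a template could be filled (an assumed helper, not this repo's code):

def fill_template(template: str, **slots: str) -> str:
    # Replace each <name> placeholder with the supplied value.
    for name, value in slots.items():
        template = template.replace(f"<{name}>", value)
    return template

# fill_template(task_template["REC"], expr="the red mug")
# -> "Can you point out the red mug in the image <image> and provide the coordinates of its location?"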
273
 
274
  with gr.Blocks() as demo:
275
  gr.Markdown(title)
 
309
  img_list = gr.State()
310
  chatbot = gr.Chatbot(label='Compositional-VLM')
311
 
 
312
  # template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False,
313
  # value='Provide a comprehensive description of the image <image> and specify the positions of any mentioned objects in square brackets.')
314
  # text_input = gr.Textbox(label='<question>', show_label=True, placeholder="Please upload your image first, then input...", lines=3,
315
  # value=None, visible=False, interactive=False)
316
 
317
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image first, then input...',
318
+ interactive=False)
319
 
320
+ upload_button.click(upload_img, [image, text_input, chat_state, chatbot],
321
+ [image, text_input, upload_button, chat_state, img_list, chatbot])
322
 
323
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state, radio], [chatbot, chat_state]).then(
324
+ gradio_answer, [chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
325
+ [text_input, chatbot, chat_state, img_list]
326
  )
327
  clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
328
  queue=False)
329
 
330
+ demo.launch(share=True)
331
  #
332
  # with gr.Blocks() as demo:
333
  # gr.Markdown(
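The launch call in the last hunk drops enable_queue=True. In recent Gradio releases that keyword is deprecated in favour of enabling the queue on the Blocks object itself, so if request queuing is still wanted the rough equivalent (an assumption about the installed Gradio version, not part of this commit) would be:

demo.queue().launch(share=True)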
multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -22,6 +22,7 @@ from huggingface_hub import hf_hub_download, login
22
  from open_flamingo.src.factory import create_model_and_transforms
23
  from open_flamingo.eval.task.caption_chat import captioner
24
 
 
25
  class SeparatorStyle(Enum):
26
  """Different separator style."""
27
  SINGLE = auto()
@@ -125,18 +126,19 @@ CONV_VISION = Conversation(
125
  sep="###",
126
  )
127
 
 
128
  def get_outputs(
129
- model,
130
- batch_images,
131
- attention_mask,
132
- max_generation_length,
133
- min_generation_length,
134
- num_beams,
135
- length_penalty,
136
- input_ids,
137
- image_start_index_list=None,
138
- image_nums=None,
139
- bad_words_ids=None,
140
  ):
141
  # and torch.cuda.amp.autocast(dtype=torch.float16)
142
  with torch.inference_mode():
@@ -165,16 +167,17 @@ def get_outputs(
165
 
166
  return outputs
167
 
 
168
  def generate(
169
- idx,
170
- image,
171
- text,
172
- image_processor,
173
- tokenizer,
174
- flamingo,
175
- vis_embed_size=256,
176
- rank=0,
177
- world_size=1,
178
  ):
179
  if image is None:
180
  raise gr.Error("Please upload an image.")
@@ -195,7 +198,8 @@ def generate(
195
  image = image.resize((224, 224))
196
  batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
197
  if idx == 1:
198
- prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
 
199
  bad_words_ids = None
200
  max_generation_length = 5
201
  else:
@@ -231,15 +235,14 @@ def generate(
231
  boxes = outputs["boxes"]
232
  scores = outputs["scores"]
233
  if len(scores) > 0:
234
- box = boxes[scores.argmax()]/224
235
  print(f"{box}")
236
 
237
-
238
- if len(boxes)>0:
239
  open_cv_image = np.array(image_ori)
240
  # Convert RGB to BGR
241
  open_cv_image = open_cv_image[:, :, ::-1].copy()
242
- box = box*[width,height,width,height]
243
  # for box in boxes:
244
  open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
245
  out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
@@ -248,6 +251,7 @@ def generate(
248
  gen_text = tokenizer.batch_decode(outputs)
249
  return (f"{gen_text}")
250
 
 
251
  def preprocess_conv(data):
252
  conversation = ""
253
  BEGIN_SIGNAL = "### "
@@ -263,14 +267,16 @@ def preprocess_conv(data):
263
  conversation += (BEGIN_SIGNAL + from_str + ": " + d["value"] + END_SIGNAL)
264
  return conversation
265
 
 
266
  def preprocess_image(sample, image_processor):
267
  image = image_processor(sample)
268
  if isinstance(image, transformers.image_processing_utils.BatchFeature):
269
  image = torch.tensor(image["pixel_values"][0])
270
  return image
271
 
 
272
  class Chat:
273
- def __init__(self, model, vis_processor, tokenizer, vis_embed_size ):
274
  self.model = model
275
  self.vis_processor = vis_processor
276
  self.tokenizer = tokenizer
@@ -280,34 +286,41 @@ class Chat:
280
  # torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
281
  # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
282
 
283
- def ask(self, text, conv,radio):
284
- if radio in ["Cap"]:
285
- conv.append({
286
- "from": "human",
287
- "value": "",
288
- })
289
- elif radio in ["VQA"]:
290
- conv.append({
291
- "from": "human",
292
- "value": f"Answer the question using a single word or phrase. {text}",
293
- })
294
- elif radio in ["REC"]:
295
- conv.append({
296
- "from": "human",
297
- "value": f"Please provide the bounding box coordinate of the region this sentence describes: {text}.",
298
- })
299
- else:
300
  conv.append({
301
  "from": "human",
302
  "value": text,
303
  })
304
  # if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
305
  # and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
306
  # conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
307
  # else:
308
  # conv.append_message(conv.roles[0], text)
309
 
310
- def answer(self, conv, img_list, radio, text_input, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
 
311
  repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
312
  # conv.append_message(conv.roles[1], None)
313
  # embs = self.get_context_emb(conv, img_list)
@@ -358,10 +371,10 @@ class Chat:
358
  image = image.resize((size, size))
359
  print(f"image size: {image.size}")
360
  batch_images = preprocess_image(image, self.vis_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
361
-
362
  # conversation = []
363
  human_sentence = None
364
- if radio in ["Cap","VQA"]:
365
  conv.append({
366
  "from": "gpt",
367
  "value": "",
@@ -375,9 +388,9 @@ class Chat:
375
  )
376
  else:
377
  conv.append({
378
- "from": "gpt",
379
- "value": "",
380
- })
381
  # while True:
382
  # human_sentence = input("### Human: ")
383
  # if human_sentence == "#end#":
@@ -390,7 +403,11 @@ class Chat:
390
  # "from": "gpt",
391
  # "value": "",
392
  # })
393
- text = preprocess_conv(conv).strip()
394
  caption = f"<|#image#|>{self.tokenizer.pad_token * self.vis_embed_size}<|#endofimage#|>{text}"
395
  encodings = self.tokenizer(
396
  caption,
@@ -406,7 +423,8 @@ class Chat:
406
  image_nums = [1] * len(input_ids)
407
  added_bbox_list = []
408
  if radio in ["Cap"]:
409
- output_text, out_image = captioner(self.model,self.tokenizer,image_ori,batch_images,input_ids,attention_mask,image_start_index_list,image_nums,added_bbox_list)
 
410
  else:
411
  with torch.inference_mode():
412
  text_outputs = self.model.generate(
@@ -439,7 +457,7 @@ class Chat:
439
  print(f"{box}")
440
  out_image = None
441
 
442
- if len(boxes)>0:
443
  width, height = image_ori.size
444
  open_cv_image = np.array(image_ori)
445
  # Convert RGB to BGR
@@ -449,7 +467,6 @@ class Chat:
449
  open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
450
  out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
451
 
452
-
453
  # output_token = outputs[0, input_ids.shape[1]:]
454
  # output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
455
  # conv[-1]["value"] = output_text
@@ -499,16 +516,17 @@ class Chat:
499
  # mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
500
  # mixed_embs = torch.cat(mixed_embs, dim=1)
501
  # return mixed_embs
502
-
 
503
  def evaluate_exp(
504
- model,
505
- tokenizer,
506
- image_processor,
507
- vis_embed_size=None,
508
- rank=0,
509
- world_size=1,
510
- id=0,
511
- add_visual=True,
512
  ):
513
  media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
514
  box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
@@ -541,7 +559,7 @@ def evaluate_exp(
541
  "value": "",
542
  })
543
  text = preprocess_conv(conversation).strip()
544
- caption = f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}"
545
  encodings = tokenizer(
546
  caption,
547
  padding="longest",
@@ -569,3 +587,4 @@ def evaluate_exp(
569
 
570
 
571
 
 
 
22
  from open_flamingo.src.factory import create_model_and_transforms
23
  from open_flamingo.eval.task.caption_chat import captioner
24
 
25
+
26
  class SeparatorStyle(Enum):
27
  """Different separator style."""
28
  SINGLE = auto()
 
126
  sep="###",
127
  )
128
 
129
+
130
  def get_outputs(
131
+ model,
132
+ batch_images,
133
+ attention_mask,
134
+ max_generation_length,
135
+ min_generation_length,
136
+ num_beams,
137
+ length_penalty,
138
+ input_ids,
139
+ image_start_index_list=None,
140
+ image_nums=None,
141
+ bad_words_ids=None,
142
  ):
143
  # and torch.cuda.amp.autocast(dtype=torch.float16)
144
  with torch.inference_mode():
 
167
 
168
  return outputs
169
 
170
+
171
  def generate(
172
+ idx,
173
+ image,
174
+ text,
175
+ image_processor,
176
+ tokenizer,
177
+ flamingo,
178
+ vis_embed_size=256,
179
+ rank=0,
180
+ world_size=1,
181
  ):
182
  if image is None:
183
  raise gr.Error("Please upload an image.")
 
198
  image = image.resize((224, 224))
199
  batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
200
  if idx == 1:
201
+ prompt = [
202
+ f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
203
  bad_words_ids = None
204
  max_generation_length = 5
205
  else:
 
235
  boxes = outputs["boxes"]
236
  scores = outputs["scores"]
237
  if len(scores) > 0:
238
+ box = boxes[scores.argmax()] / 224
239
  print(f"{box}")
240
 
241
+ if len(boxes) > 0:
 
242
  open_cv_image = np.array(image_ori)
243
  # Convert RGB to BGR
244
  open_cv_image = open_cv_image[:, :, ::-1].copy()
245
+ box = box * [width, height, width, height]
246
  # for box in boxes:
247
  open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
248
  out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
 
251
  gen_text = tokenizer.batch_decode(outputs)
252
  return (f"{gen_text}")
253
 
254
+
255
  def preprocess_conv(data):
256
  conversation = ""
257
  BEGIN_SIGNAL = "### "
 
267
  conversation += (BEGIN_SIGNAL + from_str + ": " + d["value"] + END_SIGNAL)
268
  return conversation
269
 
270
+
271
  def preprocess_image(sample, image_processor):
272
  image = image_processor(sample)
273
  if isinstance(image, transformers.image_processing_utils.BatchFeature):
274
  image = torch.tensor(image["pixel_values"][0])
275
  return image
276
 
277
+
278
  class Chat:
279
+ def __init__(self, model, vis_processor, tokenizer, vis_embed_size):
280
  self.model = model
281
  self.vis_processor = vis_processor
282
  self.tokenizer = tokenizer
 
286
  # torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
287
  # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
288
 
289
+ def ask(self, text, conv, radio, model_name):
290
+ if "pythiaS" in model_name:
291
  conv.append({
292
  "from": "human",
293
  "value": text,
294
  })
295
+ else:
296
+ if radio in ["Cap"]:
297
+ conv.append({
298
+ "from": "human",
299
+ "value": "",
300
+ })
301
+ elif radio in ["VQA"]:
302
+ conv.append({
303
+ "from": "human",
304
+ "value": f"Answer the question using a single word or phrase. {text}",
305
+ })
306
+ elif radio in ["REC"]:
307
+ conv.append({
308
+ "from": "human",
309
+ "value": f"Please provide the bounding box coordinate of the region this sentence describes: {text}.",
310
+ })
311
+ else:
312
+ conv.append({
313
+ "from": "human",
314
+ "value": text,
315
+ })
316
  # if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
317
  # and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
318
  # conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
319
  # else:
320
  # conv.append_message(conv.roles[0], text)
321
 
322
+ def answer(self, conv, img_list, radio, text_input, model_name, max_new_tokens=200, num_beams=5, min_length=1,
323
+ top_p=0.9,
324
  repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
325
  # conv.append_message(conv.roles[1], None)
326
  # embs = self.get_context_emb(conv, img_list)
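For orientation, chat_state here is a plain list of {"from", "value"} dicts: ask() appends the human turn (raw text for the pythiaS checkpoint, a task-specific instruction otherwise), answer() appends an empty assistant turn, and gradio_answer() in app.py writes the generated text back into that last slot. Schematically, after one round trip:

# chat_state == [
#     {"from": "human", "value": "<user text or task instruction>"},
#     {"from": "gpt",   "value": "<generated answer, filled in by gradio_answer()>"},
# ]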
 
371
  image = image.resize((size, size))
372
  print(f"image size: {image.size}")
373
  batch_images = preprocess_image(image, self.vis_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
374
+
375
  # conversation = []
376
  human_sentence = None
377
+ if radio in ["Cap", "VQA"]:
378
  conv.append({
379
  "from": "gpt",
380
  "value": "",
 
388
  )
389
  else:
390
  conv.append({
391
+ "from": "gpt",
392
+ "value": "",
393
+ })
394
  # while True:
395
  # human_sentence = input("### Human: ")
396
  # if human_sentence == "#end#":
 
403
  # "from": "gpt",
404
  # "value": "",
405
  # })
406
+ if "pythiaS" in model_name:
407
+ text = conv[-1]["value"].strip()
408
+ print(text)
409
+ else:
410
+ text = preprocess_conv(conv).strip()
411
  caption = f"<|#image#|>{self.tokenizer.pad_token * self.vis_embed_size}<|#endofimage#|>{text}"
412
  encodings = self.tokenizer(
413
  caption,
 
423
  image_nums = [1] * len(input_ids)
424
  added_bbox_list = []
425
  if radio in ["Cap"]:
426
+ output_text, out_image = captioner(self.model, self.tokenizer, image_ori, batch_images, input_ids,
427
+ attention_mask, image_start_index_list, image_nums, added_bbox_list)
428
  else:
429
  with torch.inference_mode():
430
  text_outputs = self.model.generate(
 
457
  print(f"{box}")
458
  out_image = None
459
 
460
+ if len(boxes) > 0:
461
  width, height = image_ori.size
462
  open_cv_image = np.array(image_ori)
463
  # Convert RGB to BGR
 
467
  open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
468
  out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
469
 
 
470
  # output_token = outputs[0, input_ids.shape[1]:]
471
  # output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
472
  # conv[-1]["value"] = output_text
 
516
  # mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
517
  # mixed_embs = torch.cat(mixed_embs, dim=1)
518
  # return mixed_embs
519
+
520
+
521
  def evaluate_exp(
522
+ model,
523
+ tokenizer,
524
+ image_processor,
525
+ vis_embed_size=None,
526
+ rank=0,
527
+ world_size=1,
528
+ id=0,
529
+ add_visual=True,
530
  ):
531
  media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
532
  box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
 
559
  "value": "",
560
  })
561
  text = preprocess_conv(conversation).strip()
562
+ caption = f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text}"
563
  encodings = tokenizer(
564
  caption,
565
  padding="longest",
 
587
 
588
 
589
 
590
+
multimodal/open_flamingo/eval/task/caption_chat.py CHANGED
@@ -51,7 +51,7 @@ def prepare_batch_images(batch, image_processor):
51
 
52
 
53
  def captioner(
54
- model,tokenizer,image_ori,batch_images,input_ids,attention_mask,image_start_index_list,image_nums,added_bbox_list,debug=False):
55
  """Evaluate a model on COCO dataset.
56
  Returns:
57
  float: CIDEr score
@@ -73,10 +73,23 @@ def captioner(
73
  object_token = "<|#object#|>"
74
  ori_prompt_length = len(input_ids[0])
75
  have_prebox = False
 
76
  while True:
77
  batch_images = batch_images
78
- input_ids = input_ids
79
- attention_mask = attention_mask
80
  image_start_index_list = image_start_index_list
81
  image_nums = image_nums
82
  if debug:
@@ -148,6 +161,7 @@ def captioner(
148
  prompt += box_token + endofobject_token
149
  if debug:
150
  print("after inserting visual---->", prompt)
 
151
  else:
152
  import numpy as np
153
  import cv2
@@ -165,8 +179,8 @@ def captioner(
165
  if debug:
166
  print("after inserting previsual---->", prompt)
167
  else:
168
- if debug:
169
- import pdb;pdb.set_trace()
170
  prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
171
  else:
172
  break
@@ -414,4 +428,5 @@ def evaluate_coco_flickr(
414
  metrics = {}
415
  metrics["CIDEr"] = 0.0
416
 
 
417
  return metrics["CIDEr"]
 
51
 
52
 
53
  def captioner(
54
+ model,tokenizer,image_ori,batch_images,input_ids,attention_mask,image_start_index_list,image_nums,added_bbox_list,debug=True):
55
  """Evaluate a model on COCO dataset.
56
  Returns:
57
  float: CIDEr score
 
73
  object_token = "<|#object#|>"
74
  ori_prompt_length = len(input_ids[0])
75
  have_prebox = False
76
+ prompt = None
77
  while True:
78
  batch_images = batch_images
79
+ if prompt == None:
80
+ input_ids = input_ids
81
+ attention_mask = attention_mask
82
+ else:
83
+
84
+ encodings = tokenizer(
85
+ [prompt],
86
+ padding="longest",
87
+ truncation=True,
88
+ return_tensors="pt",
89
+ max_length=2000,
90
+ )
91
+ attention_mask = encodings["attention_mask"]
92
+ input_ids = encodings["input_ids"]
93
  image_start_index_list = image_start_index_list
94
  image_nums = image_nums
95
  if debug:
 
161
  prompt += box_token + endofobject_token
162
  if debug:
163
  print("after inserting visual---->", prompt)
164
+
165
  else:
166
  import numpy as np
167
  import cv2
 
179
  if debug:
180
  print("after inserting previsual---->", prompt)
181
  else:
182
+ # if debug:
183
+ # import pdb;pdb.set_trace()
184
  prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
185
  else:
186
  break
 
428
  metrics = {}
429
  metrics["CIDEr"] = 0.0
430
 
431
+
432
  return metrics["CIDEr"]