chendl committed
Commit df58d6d
Parent: 86468ab

update cap

app.py CHANGED
@@ -248,7 +248,7 @@ def gradio_ask(user_message, chatbot, chat_state, radio):
 
 
 def gradio_answer(chatbot, chat_state, img_list, radio, text, num_beams, temperature):
-    image == None
+    image = None
     llm_message, image = \
        chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
                    max_length=2000, radio=radio, text_input=text)
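
The app.py fix is a single character: `image == None` is an equality test whose result is discarded, and because the name `image` does not exist yet at that point, evaluating it raises a NameError at runtime. The assignment form binds the name so the later unpacking of `chat.answer(...)` has a defined starting value. A minimal sketch of the two behaviors (function names are illustrative, not from the commit):

    def broken():
        image == None  # comparison, not assignment: NameError, 'image' is undefined
        return image

    def fixed():
        image = None   # binds the name; chat.answer(...) may later overwrite it
        return image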
multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -19,6 +19,7 @@ import gradio as gr
 from huggingface_hub import hf_hub_download, login
 
 from open_flamingo.src.factory import create_model_and_transforms
+from open_flamingo.eval.task.caption import captioner
 
 class SeparatorStyle(Enum):
     """Different separator style."""
@@ -403,56 +404,59 @@ class Chat:
         image_start_index_list = [[x] for x in image_start_index_list]
         image_nums = [1] * len(input_ids)
         added_bbox_list = []
-        with torch.inference_mode():
-            text_outputs = self.model.generate(
-                batch_images,
-                input_ids,
-                attention_mask=attention_mask,
-                max_new_tokens=20,
-                # min_new_tokens=8,
-                num_beams=1,
-                # length_penalty=0,
-                image_start_index_list=image_start_index_list,
-                image_nums=image_nums,
-                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
-            )
-        # and torch.cuda.amp.autocast(dtype=torch.float16)
-        with torch.no_grad():
-            outputs = self.model(
-                vision_x=batch_images,
-                lang_x=input_ids,
-                attention_mask=attention_mask,
-                image_nums=image_nums,
-                image_start_index_list=image_start_index_list,
-                added_bbox_list=None,
-                add_box=False,
-            )
-        boxes = outputs["boxes"]
-        scores = outputs["scores"]
-        if len(scores) > 0:
-            box = boxes[scores.argmax()] / 224
-            print(f"{box}")
-        out_image = None
-
-        if len(boxes) > 0:
-            width, height = image_ori.size
-            open_cv_image = np.array(image_ori)
-            # Convert RGB to BGR
-            open_cv_image = open_cv_image[:, :, ::-1].copy()
-            box = box * [width, height, width, height]
-            # for box in boxes:
-            open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
-            out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
-
-
-        # output_token = outputs[0, input_ids.shape[1]:]
-        # output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
-        # conv[-1]["value"] = output_text
-        # # conv.messages[-1][1] = output_text
-        # print(
-        #     f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
-        output_text = self.tokenizer.decode(text_outputs[0])
-        output_text = re.findall(r'Assistant:(.+)', output_text)[-1]
+        if radio in ["Cap"]:
+            output_text, out_image = captioner(self.model, self.tokenizer, image_ori, batch_images, input_ids, attention_mask, image_start_index_list, image_nums, added_bbox_list)
+        else:
+            with torch.inference_mode():
+                text_outputs = self.model.generate(
+                    batch_images,
+                    input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=20,
+                    # min_new_tokens=8,
+                    num_beams=1,
+                    # length_penalty=0,
+                    image_start_index_list=image_start_index_list,
+                    image_nums=image_nums,
+                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                )
+            # and torch.cuda.amp.autocast(dtype=torch.float16)
+            with torch.no_grad():
+                outputs = self.model(
+                    vision_x=batch_images,
+                    lang_x=input_ids,
+                    attention_mask=attention_mask,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=None,
+                    add_box=False,
+                )
+            boxes = outputs["boxes"]
+            scores = outputs["scores"]
+            if len(scores) > 0:
+                box = boxes[scores.argmax()] / 224
+                print(f"{box}")
+            out_image = None
+
+            if len(boxes) > 0:
+                width, height = image_ori.size
+                open_cv_image = np.array(image_ori)
+                # Convert RGB to BGR
+                open_cv_image = open_cv_image[:, :, ::-1].copy()
+                box = box * [width, height, width, height]
+                # for box in boxes:
+                open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
+                out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
+
+
+            # output_token = outputs[0, input_ids.shape[1]:]
+            # output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
+            # conv[-1]["value"] = output_text
+            # # conv.messages[-1][1] = output_text
+            # print(
+            #     f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
+            output_text = self.tokenizer.decode(text_outputs[0])
+            output_text = re.findall(r'Assistant:(.+)', output_text)[-1]
 
         return output_text, out_image
 
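Chat.answer now dispatches on the Gradio radio value: "Cap" is routed to the new captioner from open_flamingo.eval.task.caption, while every other mode keeps the original generate-then-ground path, which draws at most one predicted box (regressed in the model's 224x224 frame) onto the input image. A standalone sketch of that overlay step, with a hypothetical helper name; the scaling arithmetic mirrors the branch above:

    import numpy as np
    import cv2
    from PIL import Image

    def overlay_box(image_ori: Image.Image, box_224: np.ndarray) -> Image.Image:
        """Scale an (x1, y1, x2, y2) box from the 224x224 frame onto the original image."""
        width, height = image_ori.size
        box = box_224 / 224 * np.array([width, height, width, height])
        open_cv_image = np.array(image_ori)[:, :, ::-1].copy()  # RGB -> BGR for cv2
        open_cv_image = cv2.rectangle(
            open_cv_image, tuple(box[:2].astype(int)), tuple(box[2:].astype(int)), (255, 0, 0), 2)
        return Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))

(`radio in ["Cap"]` behaves like `radio == "Cap"`; the list form presumably leaves room for more caption-style modes later.)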
multimodal/open_flamingo/eval/task/caption.py CHANGED
@@ -7,7 +7,7 @@ import json
 import time
 import os
 from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
-
+from PIL import Image
 
 class VisualLogitsProcessor(LogitsProcessor):
     def __init__(self, tokenizer):
@@ -51,6 +51,137 @@ def prepare_batch_images(batch, image_processor):
     return batch_images
 
 
+def captioner(
+        model, tokenizer, image_ori, batch_images, input_ids, attention_mask, image_start_index_list, image_nums, added_bbox_list, debug=False):
+    """Caption a single image, grounding generated phrases to boxes as it decodes.
+    Returns:
+        (str, PIL.Image | None): the caption and an annotated image, if any
+
+    """
+    visual_logits_processor = VisualLogitsProcessor(tokenizer)
+    model.eval()
+    # model.eval().cuda()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token = "<|#box#|>"
+    prebox_token = "<|#prebox#|>"
+    endofobject_token = "<|#endofobject#|>"
+    object_token = "<|#object#|>"
+    ori_prompt_length = len(input_ids[0])
+    have_prebox = False
+    out_image = None  # stays None unless a previsual box is ever drawn
+    while True:
+        batch_images = batch_images
+        input_ids = input_ids
+        attention_mask = attention_mask
+        image_start_index_list = image_start_index_list
+        image_nums = image_nums
+        if debug:
+            print("input--->", tokenizer.decode(input_ids[0]))
+        p1 = MinNewTokensLengthLogitsProcessor(
+            prompt_length_to_skip=input_ids.shape[-1],
+            min_new_tokens=5,
+            eos_token_id=bos_token_id,
+        )
+        with torch.inference_mode():
+            outputs = model.generate(
+                batch_images,
+                input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=20,
+                # min_new_tokens=8,
+                num_beams=1,
+                # length_penalty=0,
+                image_start_index_list=image_start_index_list,
+                image_nums=image_nums,
+                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                logits_processor_list=[p1, visual_logits_processor],
+            )
+        if debug:
+            print("outputs--->", tokenizer.decode(outputs[0]))
+        if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
+            prompt = tokenizer.decode(outputs.clone()[0])
+            is_visual = (outputs[0, -2] == visual_token_id)
+            batch_text = tokenizer.batch_decode(outputs[:, :-1])
+            encodings = tokenizer(
+                batch_text,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            if debug:
+                print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
+            with torch.no_grad():
+                outputs = model(
+                    vision_x=batch_images,
+                    lang_x=input_ids,
+                    attention_mask=attention_mask,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                    add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
+                )
+            boxes = outputs["boxes"]
+            scores = outputs["scores"]
+            # if not model.valid:
+            #     import pdb; pdb.set_trace()
+            if boxes is not None:
+                if is_visual:
+                    if have_prebox:
+                        added_bbox_list.pop()
+                        prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
+                        have_prebox = False
+                        if debug:
+                            print("find previsual and remove it--->", prompt)
+                    first_box = boxes[scores.argmax()]
+                    added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
+                    prompt = prompt[:-len(tokenizer.eos_token)]
+                    prompt += box_token + endofobject_token
+                    if debug:
+                        print("after inserting visual---->", prompt)
+                else:
+                    import numpy as np
+                    import cv2
+                    open_cv_image = np.array(image_ori)
+                    open_cv_image = open_cv_image[:, :, ::-1].copy()
+                    for i, pre_box in enumerate(boxes):
+                        open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i + 1)
+                    out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
+                    # exit()
+                    pre_box = boxes[scores.argmax()]
+                    added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
+                    prompt = prompt[:-len(tokenizer.eos_token)]
+                    prompt += prebox_token + object_token
+                    have_prebox = True
+                    if debug:
+                        print("after inserting previsual---->", prompt)
+            else:
+                if debug:
+                    import pdb; pdb.set_trace()
+                prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
+        else:
+            break
+    outputs = outputs[:, ori_prompt_length:]
+    outputs = postprocess_captioning_generation(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]).replace('"', "")
+    # new_predictions = [
+    #     postprocess_captioning_generation(out).replace('"', "")
+    #     for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    # ]
+    # import pdb; pdb.set_trace()
+    return outputs, out_image
+
+
 def evaluate_coco_flickr(
     model,
     tokenizer,
@@ -94,6 +225,7 @@ def evaluate_coco_flickr(
         if ii % world_size != rank:
             continue
         cnt += len(batch)
+        batch[0]["image"] = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/images/img3.jpg").resize((224, 224))
         batch_images = prepare_batch_images(
             batch=batch,
             image_processor=image_processor,
@@ -194,13 +326,14 @@ def evaluate_coco_flickr(
                 if debug:
                     print("after inserting visual---->", prompt)
             else:
-                # import numpy as np
-                # import cv2
-                # open_cv_image = np.array(batch[0]["image"])
-                # open_cv_image = open_cv_image[:, :, ::-1].copy()
-                # for pre_box in boxes:
-                #     open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
-                # cv2.imwrite("Atest.png", open_cv_image)
+                import numpy as np
+                import cv2
+                open_cv_image = np.array(batch[0]["image"])
+                open_cv_image = open_cv_image[:, :, ::-1].copy()
+                for i, pre_box in enumerate(boxes):
+                    open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i + 1)
+                cv2.imwrite("Atest.png", open_cv_image)
+                exit()
                 pre_box = boxes[scores.argmax()]
                 added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
                 prompt = prompt[:-len(tokenizer.eos_token)]
@@ -225,6 +358,8 @@ def evaluate_coco_flickr(
         predictions[int(sample["image_id"])] = {
             "caption": new_predictions[i],
        }
+    print(new_predictions)
+    exit()
    results_path = (
        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
        if is_flickr
  if is_flickr