Commit af80823 by chendl (1 parent: 42f12c1)

update cap

app.py CHANGED
@@ -313,13 +313,16 @@ with gr.Blocks() as demo:
             #                             value='Provide a comprehensive description of the image <image> and specify the positions of any mentioned objects in square brackets.')
             # text_input = gr.Textbox(label='<question>', show_label=True, placeholder="Please upload your image first, then input...", lines=3,
             #                         value=None, visible=False, interactive=False)
-
+            # with gr.Row():
             text_input = gr.Textbox(label='User', placeholder='Please upload your image first, then input...',
                                     interactive=False)
+            # submit_button = gr.Button(value="Submit", interactive=True, variant="primary")
 
             upload_button.click(upload_img, [image, text_input, chat_state, chatbot],
                                 [image, text_input, upload_button, chat_state, img_list, chatbot])
+            # submit_button.click(gradio_ask, [text_input, chatbot, chat_state,radio], [chatbot, chat_state]).then(
+            #     gradio_answer, [chatbot, chat_state, img_list, radio, text_input,num_beams, temperature], [text_input,chatbot, chat_state, img_list]
+            # )
             text_input.submit(gradio_ask, [text_input, chatbot, chat_state, radio], [chatbot, chat_state]).then(
                 gradio_answer, [chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
                 [text_input, chatbot, chat_state, img_list]
multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -319,7 +319,8 @@ class Chat:
         # else:
         #     conv.append_message(conv.roles[0], text)
 
-    def answer(self, conv, img_list, radio, text_input, model_name, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
+    def answer(self, conv, img_list, radio, text_input, model_name, max_new_tokens=200, num_beams=5, min_length=1,
+               top_p=0.9,
                repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
         # conv.append_message(conv.roles[1], None)
         # embs = self.get_context_emb(conv, img_list)
@@ -424,7 +425,7 @@ class Chat:
         if radio in ["Cap"]:
             output_text, out_image = captioner(self.model, self.tokenizer, image_ori, batch_images, input_ids,
                                                attention_mask, image_start_index_list, image_nums, added_bbox_list)
-
+            print("asdfghkl----------------------------------------------------------------------------------------->")
         else:
             with torch.inference_mode():
                 text_outputs = self.model.generate(
@@ -477,7 +478,6 @@ class Chat:
         print(output_text)
         output_text = re.findall(r'Assistant:(.+)', output_text)[-1]
         print(output_text)
-        print(output_text)
 
         return output_text, out_image
 
multimodal/open_flamingo/eval/task/caption_chat.py CHANGED
@@ -1,4 +1,3 @@
-
 import torch
 import more_itertools
 from tqdm import tqdm
@@ -8,6 +7,7 @@ import os
 from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
 from PIL import Image
 
+
 class VisualLogitsProcessor(LogitsProcessor):
     def __init__(self, tokenizer):
         super().__init__()
@@ -24,7 +24,10 @@ class VisualLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids, scores):
         # print("decoding===>", self.tokenizer.decode(scores.sort(descending=True).indices.tolist()[0][:self.topk]))
         # import pdb; pdb.set_trace()
-        if self.object_token_id in scores.sort(descending=True).indices.tolist()[0][1:self.topk] and self.eos_token_id not in scores.sort(descending=True).indices.tolist()[0][:self.topk] and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum():
+        if self.object_token_id in scores.sort(descending=True).indices.tolist()[0][
+           1:self.topk] and self.eos_token_id not in \
+                scores.sort(descending=True).indices.tolist()[0][:self.topk] and (
+                input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum():
             scores[0, self.object_token_id] = 1000
         if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
             if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
@@ -75,7 +78,9 @@ def captioner(
     ori_prompt_length = len(input_ids[0])
     have_prebox = False
     prompt = None
-    while True:
+    out_image = None
+    no_end = True
+    while no_end:
         batch_images = batch_images
         if prompt == None:
             input_ids = input_ids
@@ -167,12 +172,7 @@ def captioner(
                 else:
                     import numpy as np
                     import cv2
-                    open_cv_image = np.array(image_ori)
-                    open_cv_image = open_cv_image[:, :, ::-1].copy()
-                    for i, pre_box in enumerate(boxes):
-                        open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int),
-                                                      (0, 255, 0), i + 1)
-                    out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
+
                     # exit()
                     pre_box = boxes[scores.argmax()]
                     added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
@@ -190,253 +190,22 @@ def captioner(
                 prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
                 if debug:
                     print("after else---->", prompt)
-
-
         else:
-            break
+            no_end = False
     outputs = outputs[:, ori_prompt_length:]
     outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].replace('"', "")
+    open_cv_image = np.array(image_ori)
+    open_cv_image = open_cv_image[:, :, ::-1].copy()
+    for i, pre_box in enumerate(added_bbox_list):
+        open_cv_image = cv2.rectangle(open_cv_image, (pre_box[:2] * 224).astype(int), (pre_box[2:] * 224).astype(int),
+                                      (0, 255, 0), i + 1)
+    out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
     # new_predictions = [
     #     postprocess_captioning_generation(out).replace('"', "")
     #     for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
     # ]
     # import pdb; pdb.set_trace()
-    print("out----------------------------------------------------------------------------------------->")
-    return outputs, out_image
-
-
-def evaluate_coco_flickr(
-    model,
-    tokenizer,
-    image_processor,
-    batch_size,
-    is_flickr=False,
-    vis_embed_size=None,
-    rank=0,
-    world_size=1,
-    id=0,
-    debug=False,
-):
-    """Evaluate a model on COCO dataset.
-    Returns:
-        float: CIDEr score
-
-    """
-    visual_logits_processor = VisualLogitsProcessor(tokenizer)
-    coco_dataset = load_dataset("coco_caption")
-    eval_dataset = coco_dataset["test"]
-    model.eval().cuda()
-    predictions = dict()
-    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
-    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
-    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
-    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
-    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
-    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
-    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
-    box_token = "<|#box#|>"
-    prebox_token = "<|#prebox#|>"
-    endofobject_token = "<|#endofobject#|>"
-    object_token = "<|#object#|>"
-    cnt = 0
-    if world_size > 1:
-        torch.distributed.barrier()
-    desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
-    for ii, batch in enumerate(more_itertools.chunked(
-            tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
-    )):
-        if ii % world_size != rank:
-            continue
-        cnt += len(batch)
-        batch[0]["image"] = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/images/img3.jpg").resize((224, 224))
-        batch_images = prepare_batch_images(
-            batch=batch,
-            image_processor=image_processor,
-        ).cuda()
-        prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
-        added_bbox_list = []
-        batch_text = [prompt for _ in batch]
-        encodings = tokenizer(
-            batch_text,
-            padding="longest",
-            truncation=True,
-            return_tensors="pt",
-            max_length=2000,
-        )
-        ori_prompt_length = len(encodings["input_ids"][0])
-        have_prebox = False
-        while True:
-            batch_text = [prompt for _ in batch]
-            encodings = tokenizer(
-                batch_text,
-                padding="longest",
-                truncation=True,
-                return_tensors="pt",
-                max_length=2000,
-            )
-            input_ids = encodings["input_ids"].cuda()
-            attention_mask = encodings["attention_mask"].cuda()
-            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
-            image_start_index_list = [[x] for x in image_start_index_list]
-            image_nums = [1] * len(input_ids)
-            if debug:
-                print("input--->",tokenizer.decode(input_ids[0]))
-            p1 = MinNewTokensLengthLogitsProcessor(
-                prompt_length_to_skip=input_ids.shape[-1],
-                min_new_tokens=5,
-                eos_token_id=bos_token_id,
-            )
-            with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
-                outputs = model.generate(
-                    batch_images,
-                    input_ids,
-                    attention_mask=attention_mask,
-                    max_new_tokens=20,
-                    # min_new_tokens=8,
-                    num_beams=1,
-                    # length_penalty=0,
-                    image_start_index_list=image_start_index_list,
-                    image_nums=image_nums,
-                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
-                    logits_processor_list=[p1, visual_logits_processor],
-                )
-            if debug:
-                print("outputs--->",tokenizer.decode(outputs[0]))
-            if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
-                prompt = tokenizer.decode(outputs.clone()[0])
-                is_visual = (outputs[0, -2] == visual_token_id)
-                batch_text = tokenizer.batch_decode(outputs[:, :-1])
-                encodings = tokenizer(
-                    batch_text,
-                    padding="longest",
-                    truncation=True,
-                    return_tensors="pt",
-                    max_length=2000,
-                )
-                input_ids = encodings["input_ids"].cuda()
-                attention_mask = encodings["attention_mask"].cuda()
-                image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
-                image_start_index_list = [[x] for x in image_start_index_list]
-                image_nums = [1] * len(input_ids)
-                if debug:
-                    print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
-                with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
-                    outputs = model(
-                        vision_x=batch_images,
-                        lang_x=input_ids,
-                        attention_mask=attention_mask,
-                        image_nums=image_nums,
-                        image_start_index_list=image_start_index_list,
-                        added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
-                        add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
-                    )
-                boxes = outputs["boxes"]
-                scores = outputs["scores"]
-                # if not model.valid:
-                #     import pdb; pdb.set_trace()
-                if boxes is not None:
-                    if is_visual:
-                        if have_prebox:
-                            added_bbox_list.pop()
-                            prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
-                            have_prebox = False
-                            if debug:
-                                print("find previsual and remove it--->", prompt)
-                        first_box = boxes[scores.argmax()]
-                        added_bbox_list += [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
-                        prompt = prompt[:-len(tokenizer.eos_token)]
-                        prompt += box_token + endofobject_token
-                        if debug:
-                            print("after inserting visual---->", prompt)
-                    else:
-                        import numpy as np
-                        import cv2
-                        open_cv_image = np.array(batch[0]["image"])
-                        open_cv_image = open_cv_image[:, :, ::-1].copy()
-                        for i, pre_box in enumerate(boxes):
-                            open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i+1)
-                        cv2.imwrite("Atest.png", open_cv_image)
-                        exit()
-                        pre_box = boxes[scores.argmax()]
-                        added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
-                        prompt = prompt[:-len(tokenizer.eos_token)]
-                        prompt += prebox_token + object_token
-                        have_prebox = True
-                        if debug:
-                            print("after inserting previsual---->", prompt)
-                else:
-                    import pdb;pdb.set_trace()
-                    prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
-            else:
-                break
-        outputs = outputs[:, ori_prompt_length:]
-        new_predictions = [
-            postprocess_captioning_generation(out).replace('"', "")
-            for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
-        ]
-        # import pdb; pdb.set_trace()
-        if rank == 0:
-            tqdm.write(new_predictions[0])
-        for i, sample in enumerate(batch):
-            predictions[int(sample["image_id"])] = {
-                "caption": new_predictions[i],
-            }
-        print(new_predictions)
-        exit()
-    results_path = (
-        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
-        if is_flickr
-        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
-    )
-    with open(results_path, "w") as f:
-        f.write(
-            json.dumps(
-                [
-                    {"image_id": k, "caption": predictions[k]["caption"]}
-                    for k in predictions
-                ],
-                indent=2,
-            )
-        )
-    print("save to", results_path)
-    del predictions
-    time.sleep(10)
-    if world_size > 1:
-        torch.distributed.barrier()
-    if rank == 0:
-        print(f"evaluate on rank {rank}. world size is {world_size}")
-        predictions = []
-        for rank_i in range(world_size):
-            part_results_path = (
-                f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
-                if is_flickr
-                else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
-            )
-            print("load", part_results_path)
-            predictions.extend(json.load(open(part_results_path)))
-            os.remove(part_results_path)
-        print("num:", len(predictions))
-        results_path = (
-            f"flickrresults_{lang_encoder_name}.json"
-            if is_flickr
-            else f"cocoresults_{lang_encoder_name}.json"
-        )
-        json.dump(predictions, open(results_path, "w"), indent=2)
 
-        metrics = compute_cider(
-            result_path=results_path,
-            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
-        )
-        metrics["CIDEr"] *= 100
-        os.makedirs("eval_results", exist_ok=True)
-        acc = metrics["CIDEr"]
-        with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
-            f.write(json.dumps(predictions, indent=2))
+    return outputs, out_image
 
-        # delete the temporary file
-        os.remove(results_path)
-    else:
-        metrics = {}
-        metrics["CIDEr"] = 0.0
 
-    return metrics["CIDEr"]