chendl committed
Commit
e770d90
Parent: e9a2cd8

update cap

Files changed (50)
  1. app.py +3 -2
  2. multimodal/build/lib/open_flamingo/__init__.py +2 -0
  3. multimodal/build/lib/open_flamingo/chat/__init__.py +0 -0
  4. multimodal/build/lib/open_flamingo/chat/conversation.py +571 -0
  5. multimodal/build/lib/open_flamingo/eval/__init__.py +1 -0
  6. multimodal/build/lib/open_flamingo/eval/classification.py +147 -0
  7. multimodal/build/lib/open_flamingo/eval/coco_metric.py +23 -0
  8. multimodal/build/lib/open_flamingo/eval/dataset_zoo/__init__.py +33 -0
  9. multimodal/build/lib/open_flamingo/eval/dataset_zoo/aro_datasets.py +365 -0
  10. multimodal/build/lib/open_flamingo/eval/dataset_zoo/constants.py +3 -0
  11. multimodal/build/lib/open_flamingo/eval/dataset_zoo/perturbations.py +194 -0
  12. multimodal/build/lib/open_flamingo/eval/dataset_zoo/retrieval.py +266 -0
  13. multimodal/build/lib/open_flamingo/eval/dataset_zoo/utils.py +15 -0
  14. multimodal/build/lib/open_flamingo/eval/eval_datasets.py +101 -0
  15. multimodal/build/lib/open_flamingo/eval/evaluate.py +1435 -0
  16. multimodal/build/lib/open_flamingo/eval/evaluate_debug.py +1159 -0
  17. multimodal/build/lib/open_flamingo/eval/evaluate_find_showcase.py +1700 -0
  18. multimodal/build/lib/open_flamingo/eval/evaluate_temp.py +1838 -0
  19. multimodal/build/lib/open_flamingo/eval/imagenet_utils.py +1007 -0
  20. multimodal/build/lib/open_flamingo/eval/ok_vqa_utils.py +213 -0
  21. multimodal/build/lib/open_flamingo/eval/task/__init__.py +0 -0
  22. multimodal/build/lib/open_flamingo/eval/task/caption.py +419 -0
  23. multimodal/build/lib/open_flamingo/eval/task/caption_chat.py +417 -0
  24. multimodal/build/lib/open_flamingo/eval/task/cola.py +220 -0
  25. multimodal/build/lib/open_flamingo/eval/task/crepe.py +93 -0
  26. multimodal/build/lib/open_flamingo/eval/task/gqa.py +248 -0
  27. multimodal/build/lib/open_flamingo/eval/task/mmbench.py +84 -0
  28. multimodal/build/lib/open_flamingo/eval/task/reg.py +141 -0
  29. multimodal/build/lib/open_flamingo/eval/task/utils.py +287 -0
  30. multimodal/build/lib/open_flamingo/eval/task/vl_checklist.py +113 -0
  31. multimodal/build/lib/open_flamingo/eval/vqa_metric.py +594 -0
  32. multimodal/build/lib/open_flamingo/src/__init__.py +0 -0
  33. multimodal/build/lib/open_flamingo/src/attention.py +45 -0
  34. multimodal/build/lib/open_flamingo/src/factory.py +269 -0
  35. multimodal/build/lib/open_flamingo/src/flamingo.py +637 -0
  36. multimodal/build/lib/open_flamingo/src/flamingo_lm.py +173 -0
  37. multimodal/build/lib/open_flamingo/src/gcn.py +137 -0
  38. multimodal/build/lib/open_flamingo/src/helpers.py +263 -0
  39. multimodal/build/lib/open_flamingo/src/utils.py +31 -0
  40. multimodal/build/lib/open_flamingo/train/__init__.py +1 -0
  41. multimodal/build/lib/open_flamingo/train/data2.py +868 -0
  42. multimodal/build/lib/open_flamingo/train/distributed.py +128 -0
  43. multimodal/build/lib/open_flamingo/train/instruction_template.py +13 -0
  44. multimodal/build/lib/open_flamingo/train/train.py +709 -0
  45. multimodal/build/lib/open_flamingo/train/train_utils.py +387 -0
  46. multimodal/open_flamingo.egg-info/PKG-INFO +247 -0
  47. multimodal/open_flamingo.egg-info/SOURCES.txt +53 -0
  48. multimodal/open_flamingo.egg-info/dependency_links.txt +1 -0
  49. multimodal/open_flamingo.egg-info/requires.txt +17 -0
  50. multimodal/open_flamingo.egg-info/top_level.txt +1 -0
app.py CHANGED
@@ -53,7 +53,8 @@ flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
 )
 
 
-checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
+checkpoint_path = "/home/aimos/huggingface/space/demo.pt"
+# hf_hub_download("chendl/compositional_test", "pythiaS.pt")
 checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
 model_state_dict = {}
 for key in checkpoint.keys():
@@ -326,7 +327,7 @@ with gr.Blocks() as demo:
 clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
             queue=False)
 
-demo.launch(enable_queue=True)
+demo.launch(enable_queue=True,share=True)
 #
 # with gr.Blocks() as demo:
 #     gr.Markdown(
multimodal/build/lib/open_flamingo/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .src.flamingo import Flamingo
2
+ from .src.factory import create_model_and_transforms
multimodal/build/lib/open_flamingo/chat/__init__.py ADDED
File without changes
multimodal/build/lib/open_flamingo/chat/conversation.py ADDED
@@ -0,0 +1,571 @@
1
+ import argparse
2
+ import time
3
+ import re
4
+ from PIL import Image
5
+
6
+ import torch
7
+ import numpy as np
8
+ import transformers
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
10
+ from transformers import StoppingCriteria, StoppingCriteriaList
11
+
12
+ import dataclasses
13
+ from enum import auto, Enum
14
+ from typing import List, Tuple, Any
15
+
16
+ import string
17
+ import cv2
18
+ import gradio as gr
19
+
20
+ from huggingface_hub import hf_hub_download, login
21
+
22
+ from open_flamingo.src.factory import create_model_and_transforms
23
+ from open_flamingo.eval.task.caption_chat import captioner
24
+
25
+ class SeparatorStyle(Enum):
26
+ """Different separator style."""
27
+ SINGLE = auto()
28
+ TWO = auto()
29
+
30
+
31
+ @dataclasses.dataclass
32
+ class Conversation:
33
+ """A class that keeps all conversation history."""
34
+ system: str
35
+ roles: List[str]
36
+ messages: List[List[str]]
37
+ offset: int
38
+ # system_img: List[Image.Image] = []
39
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
40
+ sep: str = "###"
41
+ sep2: str = None
42
+
43
+ skip_next: bool = False
44
+ conv_id: Any = None
45
+
46
+ def get_prompt(self):
47
+ if self.sep_style == SeparatorStyle.SINGLE:
48
+ ret = self.system + self.sep
49
+ for role, message in self.messages:
50
+ if message:
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ return ret
55
+ elif self.sep_style == SeparatorStyle.TWO:
56
+ seps = [self.sep, self.sep2]
57
+ ret = self.system + seps[0]
58
+ for i, (role, message) in enumerate(self.messages):
59
+ if message:
60
+ ret += role + ": " + message + seps[i % 2]
61
+ else:
62
+ ret += role + ":"
63
+ return ret
64
+ else:
65
+ raise ValueError(f"Invalid style: {self.sep_style}")
66
+
67
+ def append_message(self, role, message):
68
+ self.messages.append([role, message])
69
+
70
+ def to_gradio_chatbot(self):
71
+ ret = []
72
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
73
+ if i % 2 == 0:
74
+ ret.append([msg, None])
75
+ else:
76
+ ret[-1][-1] = msg
77
+ return ret
78
+
79
+ def copy(self):
80
+ return Conversation(
81
+ system=self.system,
82
+ # system_img=self.system_img,
83
+ roles=self.roles,
84
+ messages=[[x, y] for x, y in self.messages],
85
+ offset=self.offset,
86
+ sep_style=self.sep_style,
87
+ sep=self.sep,
88
+ sep2=self.sep2,
89
+ conv_id=self.conv_id)
90
+
91
+ def dict(self):
92
+ return {
93
+ "system": self.system,
94
+ # "system_img": self.system_img,
95
+ "roles": self.roles,
96
+ "messages": self.messages,
97
+ "offset": self.offset,
98
+ "sep": self.sep,
99
+ "sep2": self.sep2,
100
+ "conv_id": self.conv_id,
101
+ }
102
+
103
+
104
+ class StoppingCriteriaSub(StoppingCriteria):
105
+
106
+ def __init__(self, stops=[], encounters=1):
107
+ super().__init__()
108
+ self.stops = stops
109
+
110
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
111
+ for stop in self.stops:
112
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
113
+ return True
114
+
115
+ return False
116
+
117
+
118
+ CONV_VISION = Conversation(
119
+ system="Give the following image: <Img>ImageContent</Img>. "
120
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
121
+ roles=("Human", "Assistant"),
122
+ messages=[],
123
+ offset=2,
124
+ sep_style=SeparatorStyle.SINGLE,
125
+ sep="###",
126
+ )
127
+
128
+ def get_outputs(
129
+ model,
130
+ batch_images,
131
+ attention_mask,
132
+ max_generation_length,
133
+ min_generation_length,
134
+ num_beams,
135
+ length_penalty,
136
+ input_ids,
137
+ image_start_index_list=None,
138
+ image_nums=None,
139
+ bad_words_ids=None,
140
+ ):
141
+ # and torch.cuda.amp.autocast(dtype=torch.float16)
142
+ with torch.inference_mode():
143
+ outputs = model(
144
+ vision_x=batch_images,
145
+ lang_x=input_ids,
146
+ attention_mask=attention_mask,
147
+ labels=None,
148
+ image_nums=image_nums,
149
+ image_start_index_list=image_start_index_list,
150
+ added_bbox_list=None,
151
+ add_box=False,
152
+ )
153
+ # outputs = model.generate(
154
+ # batch_images,
155
+ # input_ids,
156
+ # attention_mask=attention_mask,
157
+ # max_new_tokens=max_generation_length,
158
+ # min_length=min_generation_length,
159
+ # num_beams=num_beams,
160
+ # length_penalty=length_penalty,
161
+ # image_start_index_list=image_start_index_list,
162
+ # image_nums=image_nums,
163
+ # bad_words_ids=bad_words_ids,
164
+ # )
165
+
166
+ return outputs
167
+
168
+ def generate(
169
+ idx,
170
+ image,
171
+ text,
172
+ image_processor,
173
+ tokenizer,
174
+ flamingo,
175
+ vis_embed_size=256,
176
+ rank=0,
177
+ world_size=1,
178
+ ):
179
+ if image is None:
180
+ raise gr.Error("Please upload an image.")
181
+ flamingo.eval()
182
+ loc_token_ids = []
183
+ for i in range(1000):
184
+ loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
185
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
186
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
187
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
188
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
189
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
190
+
191
+ image_ori = image
192
+ image = image.convert("RGB")
193
+ width = image.width
194
+ height = image.height
195
+ image = image.resize((224, 224))
196
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
197
+ if idx == 1:
198
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
199
+ bad_words_ids = None
200
+ max_generation_length = 5
201
+ else:
202
+ prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
203
+ bad_words_ids = loc_word_ids
204
+ max_generation_length = 300
205
+ encodings = tokenizer(
206
+ prompt,
207
+ padding="longest",
208
+ truncation=True,
209
+ return_tensors="pt",
210
+ max_length=2000,
211
+ )
212
+ input_ids = encodings["input_ids"]
213
+ attention_mask = encodings["attention_mask"]
214
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
215
+ image_start_index_list = [[x] for x in image_start_index_list]
216
+ image_nums = [1] * len(input_ids)
217
+ outputs = get_outputs(
218
+ model=flamingo,
219
+ batch_images=batch_images,
220
+ attention_mask=attention_mask,
221
+ max_generation_length=max_generation_length,
222
+ min_generation_length=4,
223
+ num_beams=1,
224
+ length_penalty=1.0,
225
+ input_ids=input_ids,
226
+ bad_words_ids=bad_words_ids,
227
+ image_start_index_list=image_start_index_list,
228
+ image_nums=image_nums,
229
+ )
230
+
231
+ boxes = outputs["boxes"]
232
+ scores = outputs["scores"]
233
+ if len(scores) > 0:
234
+ box = boxes[scores.argmax()]/224
235
+ print(f"{box}")
236
+
237
+
238
+ if len(boxes)>0:
239
+ open_cv_image = np.array(image_ori)
240
+ # Convert RGB to BGR
241
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
242
+ box = box*[width,height,width,height]
243
+ # for box in boxes:
244
+ open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
245
+ out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
246
+ return f"Output:{box}", out_image
247
+ else:
248
+ gen_text = tokenizer.batch_decode(outputs)
249
+ return (f"{gen_text}")
250
+
251
+ def preprocess_conv(data):
252
+ conversation = ""
253
+ BEGIN_SIGNAL = "### "
254
+ END_SIGNAL = "\n"
255
+ for idx, d in enumerate(data):
256
+ from_str = d["from"]
257
+ if from_str.lower() == "human":
258
+ from_str = "Human"
259
+ elif from_str.lower() == "gpt":
260
+ from_str = "Assistant"
261
+ else:
262
+ from_str = 'unknown'
263
+ conversation += (BEGIN_SIGNAL + from_str + ": " + d["value"] + END_SIGNAL)
264
+ return conversation
265
+
266
+ def preprocess_image(sample, image_processor):
267
+ image = image_processor(sample)
268
+ if isinstance(image, transformers.image_processing_utils.BatchFeature):
269
+ image = torch.tensor(image["pixel_values"][0])
270
+ return image
271
+
272
+ class Chat:
273
+ def __init__(self, model, vis_processor, tokenizer, vis_embed_size ):
274
+ self.model = model
275
+ self.vis_processor = vis_processor
276
+ self.tokenizer = tokenizer
277
+ self.vis_embed_size = vis_embed_size
278
+ self.conv = []
279
+ # stop_words_ids = [torch.tensor([835]).to(self.device),
280
+ # torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
281
+ # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
282
+
283
+ def ask(self, text, conv,radio):
284
+ if radio in ["Cap"]:
285
+ conv.append({
286
+ "from": "human",
287
+ "value": "",
288
+ })
289
+ elif radio in ["VQA"]:
290
+ conv.append({
291
+ "from": "human",
292
+ "value": f"Answer the question using a single word or phrase. {text}",
293
+ })
294
+ elif radio in ["REC"]:
295
+ conv.append({
296
+ "from": "human",
297
+ "value": f"Please provide the bounding box coordinate of the region this sentence describes: {text}.",
298
+ })
299
+ else:
300
+ conv.append({
301
+ "from": "human",
302
+ "value": text,
303
+ })
304
+ # if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
305
+ # and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
306
+ # conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
307
+ # else:
308
+ # conv.append_message(conv.roles[0], text)
309
+
310
+ def answer(self, conv, img_list, radio, text_input, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
311
+ repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
312
+ # conv.append_message(conv.roles[1], None)
313
+ # embs = self.get_context_emb(conv, img_list)
314
+ #
315
+ # # current_max_len = embs.shape[1] + max_new_tokens + 100
316
+ # # begin_idx = max(0, current_max_len - max_length)
317
+ # # embs = embs[:, begin_idx:]
318
+ # outputs = self.model.llama_model.generate(
319
+ # inputs_embeds=embs,
320
+ # max_new_tokens=max_new_tokens,
321
+ # stopping_criteria=self.stopping_criteria,
322
+ # num_beams=num_beams,
323
+ # min_length=min_length,
324
+ # top_p=top_p,
325
+ # repetition_penalty=repetition_penalty,
326
+ # length_penalty=length_penalty,
327
+ # temperature=temperature,
328
+ # )
329
+ # output_token = outputs[0]
330
+ # if output_token[0] == 0:
331
+ # output_token = output_token[1:]
332
+ # output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
333
+ # output_text = output_text.split('###')[0] # remove the stop sign '###'
334
+ # output_text = output_text.split('Assistant:')[-1].strip()
335
+ # conv.messages[-1][1] = output_text
336
+ visual_token = "<|#visual#|>"
337
+ previsual_token = "<|#previsual#|>"
338
+ box_token = "<|#box#|>"
339
+ prebox_token = "<|#prebox#|>"
340
+ end_token = "<|#endofobject#|>"
341
+ object_token = "<|#object#|>"
342
+ end_of_attr_token = "<|#endofattr#|>"
343
+ preend_of_attr_token = "<|#preendofattr#|>"
344
+ media_token_id = self.tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
345
+ box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
346
+ endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
347
+ endofattr_token_id = self.tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
348
+ endofmedia_token_id = self.tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
349
+ visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
350
+ previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
351
+ prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
352
+ size = 224
353
+ self.model.eval()
354
+ # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
355
+ # image_path = input("Please enter the image path: ")
356
+ image = img_list[0].convert("RGB")
357
+ image_ori = image
358
+ image = image.resize((size, size))
359
+ print(f"image size: {image.size}")
360
+ batch_images = preprocess_image(image, self.vis_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
361
+
362
+ # conversation = []
363
+ human_sentence = None
364
+ if radio in ["Cap","VQA"]:
365
+ conv.append({
366
+ "from": "gpt",
367
+ "value": "",
368
+ })
369
+ elif radio in ["REC"]:
370
+ conv.append(
371
+ {
372
+ "from": "gpt",
373
+ "value": object_token + text_input + end_token + visual_token,
374
+ }
375
+ )
376
+ else:
377
+ conv.append({
378
+ "from": "gpt",
379
+ "value": "",
380
+ })
381
+ # while True:
382
+ # human_sentence = input("### Human: ")
383
+ # if human_sentence == "#end#":
384
+ # break
385
+ # conversation.append({
386
+ # "from": "human",
387
+ # "value": human_sentence,
388
+ # })
389
+ # conversation.append({
390
+ # "from": "gpt",
391
+ # "value": "",
392
+ # })
393
+ text = preprocess_conv(conv).strip()
394
+ caption = f"<|#image#|>{self.tokenizer.pad_token * self.vis_embed_size}<|#endofimage#|>{text}"
395
+ encodings = self.tokenizer(
396
+ caption,
397
+ padding="longest",
398
+ truncation=True,
399
+ return_tensors="pt",
400
+ max_length=2000,
401
+ )
402
+ input_ids = encodings["input_ids"]
403
+ attention_mask = encodings["attention_mask"]
404
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
405
+ image_start_index_list = [[x] for x in image_start_index_list]
406
+ image_nums = [1] * len(input_ids)
407
+ added_bbox_list = []
408
+ if radio in ["Cap"]:
409
+ output_text, out_image = captioner(self.model,self.tokenizer,image_ori,batch_images,input_ids,attention_mask,image_start_index_list,image_nums,added_bbox_list)
410
+ else:
411
+ with torch.inference_mode():
412
+ text_outputs = self.model.generate(
413
+ batch_images,
414
+ input_ids,
415
+ attention_mask=attention_mask,
416
+ max_new_tokens=20,
417
+ # min_new_tokens=8,
418
+ num_beams=1,
419
+ # length_penalty=0,
420
+ image_start_index_list=image_start_index_list,
421
+ image_nums=image_nums,
422
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
423
+ )
424
+ # and torch.cuda.amp.autocast(dtype=torch.float16)
425
+ with torch.no_grad():
426
+ outputs = self.model(
427
+ vision_x=batch_images,
428
+ lang_x=input_ids,
429
+ attention_mask=attention_mask,
430
+ image_nums=image_nums,
431
+ image_start_index_list=image_start_index_list,
432
+ added_bbox_list=None,
433
+ add_box=False,
434
+ )
435
+ boxes = outputs["boxes"]
436
+ scores = outputs["scores"]
437
+ if len(scores) > 0:
438
+ box = boxes[scores.argmax()] / 224
439
+ print(f"{box}")
440
+ out_image = None
441
+
442
+ if len(boxes)>0:
443
+ width, height = image_ori.size
444
+ open_cv_image = np.array(image_ori)
445
+ # Convert RGB to BGR
446
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
447
+ box = box * [width, height, width, height]
448
+ # for box in boxes:
449
+ open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
450
+ out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
451
+
452
+
453
+ # output_token = outputs[0, input_ids.shape[1]:]
454
+ # output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
455
+ # conv[-1]["value"] = output_text
456
+ # # conv.messages[-1][1] = output_text
457
+ # print(
458
+ # f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
459
+ output_text = self.tokenizer.decode(text_outputs[0])
460
+ print(output_text)
461
+ output_text = re.findall(r'Assistant:(.+)', output_text)[-1]
462
+ print(output_text)
463
+
464
+ return output_text, out_image
465
+
466
+ def upload_img(self, image, conv, img_list):
467
+ img_list.append(image)
468
+ # if isinstance(image, str): # is a image path
469
+ # raw_image = Image.open(image).convert('RGB')
470
+ # image = image.resize((224, 224))
471
+ # image = self.vis_processor(raw_image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
472
+ # elif isinstance(image, Image.Image):
473
+ # raw_image = image
474
+ # image = image.resize((224, 224))
475
+ # image = self.vis_processor(raw_image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
476
+ # elif isinstance(image, torch.Tensor):
477
+ # if len(image.shape) == 3:
478
+ # image = image.unsqueeze(0)
479
+ # # image = image.to(self.device)
480
+ #
481
+ # # image_emb, _ = self.model.encode_img(image)
482
+ # img_list.append(image_emb)
483
+ # conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
484
+ msg = "Received."
485
+ # self.conv.append_message(self.conv.roles[1], msg)
486
+ return msg
487
+
488
+ # def get_context_emb(self, conv, img_list):
489
+ # prompt = conv.get_prompt()
490
+ # prompt_segs = prompt.split('<ImageHere>')
491
+ # assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
492
+ # seg_tokens = [
493
+ # self.model.llama_tokenizer(
494
+ # seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
495
+ # # only add bos to the first seg
496
+ # for i, seg in enumerate(prompt_segs)
497
+ # ]
498
+ # seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
499
+ # mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
500
+ # mixed_embs = torch.cat(mixed_embs, dim=1)
501
+ # return mixed_embs
502
+
503
+ def evaluate_exp(
504
+ model,
505
+ tokenizer,
506
+ image_processor,
507
+ vis_embed_size=None,
508
+ rank=0,
509
+ world_size=1,
510
+ id=0,
511
+ add_visual=True,
512
+ ):
513
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
514
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
515
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
516
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
517
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
518
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
519
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
520
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
521
+ size = image_processor.size["shortest_edge"]
522
+ model.eval()
523
+ # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
524
+ image_path = input("Please enter the image path: ")
525
+ image = Image.open(image_path).convert("RGB")
526
+ image = image.resize((size, size))
527
+ print(f"image size: {image.size}")
528
+ batch_images = preprocess_image(image, image_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
529
+ conversation = []
530
+ human_sentence = None
531
+ while True:
532
+ human_sentence = input("### Human: ")
533
+ if human_sentence == "#end#":
534
+ break
535
+ conversation.append({
536
+ "from": "human",
537
+ "value": human_sentence,
538
+ })
539
+ conversation.append({
540
+ "from": "gpt",
541
+ "value": "",
542
+ })
543
+ text = preprocess_conv(conversation).strip()
544
+ caption = f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}"
545
+ encodings = tokenizer(
546
+ caption,
547
+ padding="longest",
548
+ truncation=True,
549
+ return_tensors="pt",
550
+ max_length=2000,
551
+ )
552
+ input_ids = encodings["input_ids"].to("cuda")
553
+ attention_mask = encodings["attention_mask"].to("cuda")
554
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
555
+ image_start_index_list = [[x] for x in image_start_index_list]
556
+ image_nums = [1] * len(input_ids)
557
+ with torch.no_grad() and torch.cuda.amp.autocast(dtype=torch.float16):
558
+ outputs = model.generate(
559
+ batch_images,
560
+ input_ids,
561
+ attention_mask=attention_mask,
562
+ max_new_tokens=100,
563
+ # min_new_tokens=8,
564
+ num_beams=1,
565
+ image_start_index_list=image_start_index_list,
566
+ image_nums=image_nums,
567
+ )
568
+ print(f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
569
+
570
+
571
+
multimodal/build/lib/open_flamingo/eval/__init__.py ADDED
@@ -0,0 +1 @@
1
+
multimodal/build/lib/open_flamingo/eval/classification.py ADDED
@@ -0,0 +1,147 @@
1
+ from typing import Dict, Sequence, Tuple
2
+ import re
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def postprocess_classification_generation(predictions) -> str:
8
+ return re.split("Prompt|Completion", predictions, 1)[0]
9
+
10
+
11
+ def compute_classification_accuracy(predictions: Sequence[Dict[str, str]]) -> float:
12
+ """Compute the accuracy of a sequence of predictions."""
13
+
14
+ def _preprocess_fn(s):
15
+ """Function to preprocess both targets and predictions."""
16
+ return s.lower()
17
+
18
+ is_correct = [
19
+ _preprocess_fn(x["prediction"]) == _preprocess_fn(x["class_label"])
20
+ for x in predictions
21
+ ]
22
+
23
+ return np.mean(is_correct).item()
24
+
25
+
26
+ def compute_shifted_logits_and_labels(
27
+ logits: torch.Tensor, encodings, tokenizer, eoc_token_id
28
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
29
+ """Helper function to compute shifted logits and labels.
30
+
31
+ This allows for straightforward computation of the loss on shift_logits
32
+ and shift_labels such that the nth element of logits computes the n-1th
33
+ element of the original labels (in the outputs, the nth element of logits
34
+ corresponds to the nth element of the labels).
35
+
36
+ Elements in shift_labels that correspond to inputs are masked with values
37
+ of -100 (by default in hf, loss is only computed on token IDs >= 0).
38
+
39
+ Returns: tuple containing two elements:
40
+ shift_logits: a float Tensor of shape [batch_size, seq_len - 1].
41
+ shift_labels: an integer Tensor of shape [batch_size, seq_len - 1]
42
+ """
43
+
44
+ labels = encodings["input_ids"].clone()
45
+
46
+ # convert padding and EOC tokens to -100 so they are ignored in loss
47
+ labels[labels == tokenizer.pad_token_id] = -100
48
+ labels[labels == eoc_token_id] = -100
49
+
50
+ # Convert all tokens in prefix until separator to -100 so they are
51
+ # ignored in loss
52
+ for idx in range(len(labels)):
53
+ # Find the location of the last token of prefix *from right*,
54
+ # since the first non-padding token of the sequence will also be
55
+ # eos_token (because bos_token and eos_token are the same for
56
+ # the tokenizer).
57
+ end_of_prefix = -labels[idx].tolist()[::-1].index(tokenizer.eos_token_id) - 1
58
+ labels[idx, : end_of_prefix + 1] = -100
59
+
60
+ # Shift so that tokens < n predict n. The shifted tensors both have
61
+ # shape [batch_size, seq_len - 1].
62
+ shift_logits = logits[..., :-1, :].contiguous()
63
+ shift_labels = labels[..., 1:].contiguous()
64
+
65
+ return shift_logits, shift_labels
66
+
67
+
68
+ def compute_per_sample_probs(
69
+ encodings, tokenizer, logits: torch.Tensor, eoc_token_id
70
+ ) -> torch.Tensor:
71
+ """Helper function to compute per-sample probability of the input sequence.
72
+
73
+ Assumes <eos token> is used to separate inputs from targets in the
74
+ prompt text
75
+ """
76
+ shift_logits, shift_labels = compute_shifted_logits_and_labels(
77
+ logits, encodings, tokenizer, eoc_token_id
78
+ )
79
+
80
+ # Tuple of tensors for unmasked label tokens. The first element of the
81
+ # tuple contains the batch indices; the second element contains the
82
+ # sequence indices.
83
+ unmasked_indices = torch.nonzero(shift_labels != -100, as_tuple=True)
84
+ # Tensor where the i^th element is the token_id corresponding to the i^th
85
+ # element of unmasked_indices
86
+ unmasked_token_ids = shift_labels[unmasked_indices]
87
+
88
+ # 3d tensor of [batch_idx, sequence_position, token_id] for unmasked tokens.
89
+ target_idxs = torch.column_stack([*unmasked_indices, unmasked_token_ids])
90
+ target_idxs = target_idxs.to(shift_logits.device)
91
+
92
+ # Sanity check that every element in batch has at least one unmasked
93
+ # target token
94
+ assert torch.all(
95
+ torch.bincount(target_idxs[:, 0]) != 0
96
+ ), "At least one element in batch has no unmasked target tokens."
97
+
98
+ # Renormalize over tokens to make sure they are proper probabilities via
99
+ # softmax over the token dimension.
100
+ shift_probs = torch.nn.functional.softmax(shift_logits, 2)
101
+
102
+ # Compute the probability of the target sequence (as the product of the
103
+ # probability of the individual tokens in the sequence).
104
+ target_probs = torch.ones(len(shift_labels), device=shift_logits.device)
105
+ for i, j, k in target_idxs:
106
+ target_probs[i] *= shift_probs[i, j, k]
107
+
108
+ return target_probs
109
+
110
+
111
+ def compute_per_sample_loss(encodings, tokenizer, logits, eoc_token_id) -> torch.Tensor:
112
+ """Helper function to compute per-sample classification loss.
113
+
114
+ Assumes <eos token> is used to separate inputs from targets in the
115
+ prompt text
116
+ """
117
+ shift_logits, shift_labels = compute_shifted_logits_and_labels(
118
+ logits, encodings, tokenizer, eoc_token_id
119
+ )
120
+
121
+ device = shift_logits.device
122
+
123
+ # Loss is computed token-wise, on Tensors of shape
124
+ # [batch_size * (seq_len - 1), vocab_size]
125
+ # and returns a loss tensor of shape
126
+ # [batch_size * (seq_len - 1)]. Most of the tokens will be masked
127
+ # in this computation.
128
+ loss = torch.nn.functional.cross_entropy(
129
+ shift_logits.view(-1, shift_logits.size(-1)),
130
+ shift_labels.view(-1).to(device),
131
+ reduction="none",
132
+ )
133
+
134
+ # Reshape to [batch_size, seq_len - 1]
135
+ loss = loss.view(shift_logits.size(0), shift_logits.size(1)).cpu()
136
+
137
+ # loss_mask is 1 for tokens we want included in the loss, and 0 for tokens
138
+ # that should be ignored in the loss.
139
+ loss_mask = (shift_labels != -100).int().cpu()
140
+
141
+ loss *= loss_mask
142
+
143
+ # Compute per-element loss : sum loss over all (unmasked) tokens and
144
+ # divide by number of variable tokens to obtain tensor of
145
+ # shape [batch_size,]
146
+ loss = loss.sum(dim=1) / (shift_labels != -100).sum(dim=1).float()
147
+ return loss
multimodal/build/lib/open_flamingo/eval/coco_metric.py ADDED
@@ -0,0 +1,23 @@
1
+ from pycocoevalcap.eval import COCOEvalCap
2
+ from pycocotools.coco import COCO
3
+ import json
4
+
5
+
6
+ def compute_cider(
7
+ result_path,
8
+ annotations_path,
9
+ ):
10
+ # create coco object and coco_result object
11
+ coco = COCO(annotations_path)
12
+ coco_result = coco.loadRes(result_path)
13
+
14
+ # create coco_eval object by taking coco and coco_result
15
+ coco_eval = COCOEvalCap(coco, coco_result)
16
+ coco_eval.params["image_id"] = coco_result.getImgIds()
17
+ coco_eval.evaluate()
18
+
19
+ return coco_eval.eval
20
+
21
+
22
+ def postprocess_captioning_generation(predictions):
23
+ return predictions
multimodal/build/lib/open_flamingo/eval/dataset_zoo/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ from .aro_datasets import VG_Relation, VG_Attribution, COCO_Order, Flickr30k_Order
2
+ from .retrieval import COCO_Retrieval, Flickr30k_Retrieval
3
+
4
+
5
+ def get_dataset(dataset_name, image_preprocess=None, text_perturb_fn=None, image_perturb_fn=None, download=False, *args, **kwargs):
6
+ """
7
+ Helper function that returns a dataset object with an evaluation function.
8
+ dataset_name: Name of the dataset.
9
+ image_preprocess: Preprocessing function for images.
10
+ text_perturb_fn: A function that takes in a string and returns a string. This is for perturbation experiments.
11
+ image_perturb_fn: A function that takes in a PIL image and returns a PIL image. This is for perturbation experiments.
12
+ download: Whether to allow downloading images if they are not found.
13
+ """
14
+ if dataset_name == "VG_Relation":
15
+ from .aro_datasets import get_visual_genome_relation
16
+ return get_visual_genome_relation(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
17
+ elif dataset_name == "VG_Attribution":
18
+ from .aro_datasets import get_visual_genome_attribution
19
+ return get_visual_genome_attribution(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
20
+ elif dataset_name == "COCO_Order":
21
+ from .aro_datasets import get_coco_order
22
+ return get_coco_order(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
23
+ elif dataset_name == "Flickr30k_Order":
24
+ from .aro_datasets import get_flickr30k_order
25
+ return get_flickr30k_order(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
26
+ elif dataset_name == "COCO_Retrieval":
27
+ from .retrieval import get_coco_retrieval
28
+ return get_coco_retrieval(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
29
+ elif dataset_name == "Flickr30k_Retrieval":
30
+ from .retrieval import get_flickr30k_retrieval
31
+ return get_flickr30k_retrieval(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
32
+ else:
33
+ raise ValueError(f"Unknown dataset {dataset_name}")
multimodal/build/lib/open_flamingo/eval/dataset_zoo/aro_datasets.py ADDED
@@ -0,0 +1,365 @@
1
+ import os
2
+ import json
3
+ import subprocess
4
+
5
+ import numpy as np
6
+
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+ from torch.utils.data import Dataset
10
+ from easydict import EasyDict as edict
11
+ from torchvision.datasets.utils import download_url
12
+
13
+ from .perturbations import TextShuffler
14
+ from .constants import ARO_ROOT, COCO_ROOT, FLICKR_ROOT
15
+ from .retrieval import pre_caption
16
+
17
+
18
+ class VG_Relation(Dataset):
19
+ def __init__(self, image_preprocess, text_perturb_fn=None, image_perturb_fn=None, root_dir=ARO_ROOT, download=False):
20
+ '''
21
+ image_preprocess: a function that takes in a PIL image and returns a tensor.
22
+ text_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
23
+ image_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
24
+ root_dir: Directory for the VG-R dataset.
25
+ download: Whether to download the dataset if it does not exist.
26
+ '''
27
+ self.root_dir = root_dir
28
+ annotation_file = os.path.join(root_dir, "visual_genome_relation.json")
29
+ image_dir = os.path.join(root_dir, "images")
30
+ if not os.path.exists(image_dir):
31
+ print("Image Directory for VG_Relation could not be found!")
32
+ if download:
33
+ self.download()
34
+ else:
35
+ raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
36
+
37
+ if not os.path.exists(annotation_file):
38
+ subprocess.call(["gdown", "--id", "1kX2iCHEv0CADL8dSO1nMdW-V0NqIAiP3", "--output", annotation_file])
39
+
40
+ with open(annotation_file, "r") as f:
41
+ self.dataset = json.load(f)
42
+
43
+ self.all_relations = list()
44
+ for item in self.dataset:
45
+ item["image_path"] = os.path.join(image_dir, item["image_path"])
46
+ self.all_relations.append(item["relation_name"])
47
+
48
+ self.image_preprocess = image_preprocess
49
+
50
+ def __len__(self):
51
+ return len(self.dataset)
52
+
53
+ def __getitem__(self, index):
54
+ test_case = self.dataset[index]
55
+ image = Image.open(test_case["image_path"]).convert('RGB')
56
+ # Get the bounding box that contains the relation. This is to remove the irrelevant details in the scene.
57
+ image = image.crop((test_case["bbox_x"], test_case["bbox_y"], test_case["bbox_x"] + test_case["bbox_w"], test_case["bbox_y"] + test_case["bbox_h"]))
58
+
59
+ if self.image_preprocess is not None:
60
+ image = self.image_preprocess(image)
61
+
62
+ # Each test case has a correct and incorrect caption.
63
+ true_caption = test_case["true_caption"]
64
+ false_caption = test_case["false_caption"]
65
+ item = edict({"image_options": [image], "caption_options": [false_caption, true_caption]})
66
+ return item
67
+
68
+ def download(self):
69
+ os.makedirs(self.root_dir, exist_ok=True)
70
+ image_zip_file = os.path.join(self.root_dir, "vgr_vga_images.zip")
71
+ subprocess.call(["gdown", "--no-cookies", "1qaPlrwhGNMrR3a11iopZUT_GPP_LrgP9", "--output", image_zip_file])
72
+ subprocess.call(["unzip", "vgr_vga_images.zip"], cwd=self.root_dir)
73
+
74
+
75
+ def evaluate_scores(self, scores):
76
+ """
77
+ Scores: N x 1 x 2, i.e. first caption is the perturbed one, second is the positive one
78
+ """
79
+ if isinstance(scores, tuple):
80
+ scores_i2t = scores[1]
81
+ scores_t2i = scores[0]
82
+ else:
83
+ scores_t2i = scores
84
+ scores_i2t = scores
85
+
86
+ metrics = {"Accuracy": None}
87
+ preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
88
+ correct_mask = (preds == 1)
89
+ metrics["Accuracy"] = np.mean(correct_mask)
90
+
91
+ all_relations = np.array(self.all_relations)
92
+
93
+ result_records = []
94
+ # Log the accuracy of all relations
95
+ for relation in np.unique(all_relations):
96
+ relation_mask = (all_relations == relation)
97
+ if relation_mask.sum() == 0:
98
+ continue
99
+ result_records.append({
100
+ "Relation": relation,
101
+ "Accuracy": correct_mask[relation_mask].mean(),
102
+ "Count": relation_mask.sum(),
103
+ "Dataset": "Visual Genome Relation"
104
+ })
105
+ return result_records
106
+
107
+
108
+
109
+ class VG_Attribution(Dataset):
110
+ def __init__(self, image_preprocess, text_perturb_fn=None, image_perturb_fn=None, root_dir=ARO_ROOT, download=False):
111
+ '''
112
+ image_preprocess: a function that takes in a PIL image and returns a tensor.
113
+ text_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
114
+ image_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
115
+ root_dir: Directory for the VG-A dataset.
116
+ '''
117
+ self.root_dir = root_dir
118
+ annotation_file = os.path.join(root_dir, "visual_genome_attribution.json")
119
+ image_dir = os.path.join(root_dir, "images")
120
+ if not os.path.exists(image_dir):
121
+ print("Image Directory for VG_Attribution could not be found!")
122
+ if download:
123
+ self.download()
124
+ else:
125
+ raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
126
+
127
+
128
+ if not os.path.exists(annotation_file):
129
+ subprocess.call(["gdown", "--id", "13tWvOrNOLHxl3Rm9cR3geAdHx2qR3-Tw", "--output", annotation_file])
130
+
131
+ with open(annotation_file, "r") as f:
132
+ self.dataset = json.load(f)
133
+
134
+ for item in self.dataset:
135
+ item["image_path"] = os.path.join(image_dir, item["image_path"])
136
+
137
+ # Set of attributes in each test case
138
+ self.all_attributes = [f"{item['attributes'][0]}_{item['attributes'][1]}" for item in self.dataset]
139
+ self.image_preprocess = image_preprocess
140
+
141
+ def __len__(self):
142
+ return len(self.dataset)
143
+
144
+ def __getitem__(self, index):
145
+ test_case = self.dataset[index]
146
+ image = Image.open(test_case["image_path"]).convert('RGB')
147
+ # Get the bounding box that contains the relation. This is to remove the irrelevant details in the scene.
148
+ image = image.crop((test_case["bbox_x"], test_case["bbox_y"], test_case["bbox_x"] + test_case["bbox_w"], test_case["bbox_y"] + test_case["bbox_h"]))
149
+
150
+ if self.image_preprocess is not None:
151
+ image = self.image_preprocess(image)
152
+
153
+ # Each test case has a correct and incorrect caption.
154
+ true_caption = test_case["true_caption"]
155
+ false_caption = test_case["false_caption"]
156
+ item = edict({"image_options": [image], "caption_options": [false_caption, true_caption]})
157
+ return item
158
+
159
+ def download(self):
160
+ os.makedirs(self.root_dir, exist_ok=True)
161
+ image_zip_file = os.path.join(self.root_dir, "vgr_vga_images.zip")
162
+ subprocess.call(["gdown", "--no-cookies", "1qaPlrwhGNMrR3a11iopZUT_GPP_LrgP9", "--output", image_zip_file])
163
+ subprocess.call(["unzip", "vgr_vga_images.zip"], cwd=self.root_dir)
164
+
165
+
166
+ def evaluate_scores(self, scores):
167
+ """
168
+ Scores: N x 1 x 2, i.e. first caption is the perturbed one, second is the positive one
169
+ """
170
+ if isinstance(scores, tuple):
171
+ scores_i2t = scores[1]
172
+ scores_t2i = scores[0]
173
+ else:
174
+ scores_t2i = scores
175
+ scores_i2t = scores
176
+
177
+ preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
178
+ correct_mask = (preds == 1)
179
+ result_records = []
180
+ all_attributes = np.array(self.all_attributes)
181
+ for attr in np.unique(all_attributes):
182
+ attr_mask = (all_attributes == attr)
183
+ if attr_mask.sum() < 25:
184
+ continue
185
+ result_records.append({
186
+ "Attributes": attr,
187
+ "Accuracy": correct_mask[attr_mask].mean(),
188
+ "Count": attr_mask.sum(),
189
+ "Dataset": "Visual Genome Attribution"
190
+ })
191
+ return result_records
192
+
193
+
194
+
195
+
196
+ class COCO_Order(Dataset):
197
+ def __init__(self, image_preprocess=None, root_dir=COCO_ROOT, max_words=30, split="test",
198
+ image_perturb_fn=None, download=False):
199
+ """
200
+ COCO Order Dataset.
201
+ image_preprocess: image preprocessing function
202
+ root_dir: The directory of the coco dataset. This directory should contain test2014 files.
203
+ max_words: Cropping the caption to max_words.
204
+ split: 'val' or 'test'
205
+ image_perturb_fn: not used; for compatibility.
206
+ download: Whether to download the dataset if it does not exist.
207
+ """
208
+ shuffler = TextShuffler()
209
+ perturb_functions = [shuffler.shuffle_nouns_and_adj, shuffler.shuffle_allbut_nouns_and_adj,
210
+ shuffler.shuffle_within_trigrams, shuffler.shuffle_trigrams]
211
+
212
+ self.root_dir = root_dir
213
+ if not os.path.exists(root_dir):
214
+ print("Directory for COCO could not be found!")
215
+ if download:
216
+ print("Downloading COCO now.")
217
+ self.download()
218
+ else:
219
+ raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
220
+
221
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
222
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
223
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
224
+ download_url(urls[split],root_dir)
225
+
226
+ self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
227
+ self.image_preprocess = image_preprocess
228
+ self.image_root = root_dir
229
+
230
+ self.test_cases = []
231
+
232
+ for img_id, ann in tqdm(enumerate(self.annotation)):
233
+ for i, caption in enumerate(ann['caption']):
234
+ test_case = {}
235
+ test_case["image"] = ann["image"]
236
+ test_case["caption_options"] = [pre_caption(caption,max_words)]
237
+
238
+ for perturb_fn in perturb_functions:
239
+ test_case["caption_options"].append(pre_caption(perturb_fn(caption), max_words))
240
+ self.test_cases.append(test_case)
241
+
242
+ def __len__(self):
243
+ return len(self.test_cases)
244
+
245
+ def __getitem__(self, index):
246
+ test_case = self.test_cases[index]
247
+ image_path = os.path.join(self.image_root, test_case["image"])
248
+
249
+ image = Image.open(image_path).convert('RGB')
250
+ if self.image_preprocess is not None:
251
+ image = self.image_preprocess(image)
252
+
253
+ item = edict({"image_options": [image], "caption_options": test_case["caption_options"]})
254
+ return item
255
+
256
+ def download(self):
257
+ import subprocess
258
+ os.makedirs(self.root_dir, exist_ok=True)
259
+ #subprocess.call(["wget", "http://images.cocodataset.org/zips/train2014.zip"], cwd=self.root_dir)
260
+ #subprocess.call(["unzip", "train2014.zip"], cwd=self.root_dir)
261
+
262
+ subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=self.root_dir)
263
+ subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)
264
+
265
+ subprocess.call(["wget", "http://images.cocodataset.org/zips/test2014.zip"], cwd=self.root_dir)
266
+ subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)
267
+
268
+
269
+ def evaluate_scores(self, scores):
270
+ if isinstance(scores, tuple):
271
+ scores_i2t = scores[0]
272
+ scores_t2i = scores[1].T # Make it N_ims x N_text
273
+
274
+ else:
275
+ scores_t2i = scores
276
+ scores_i2t = scores
277
+
278
+ preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
279
+ correct_mask = (preds == 0)
280
+ records = [{"Precision@1": np.mean(correct_mask)}]
281
+ return records
282
+
283
+
284
+ class Flickr30k_Order(Dataset):
285
+ def __init__(self, image_preprocess, split, root_dir=FLICKR_ROOT, max_words=30,
286
+ *args, **kwargs):
287
+ """
288
+ image_preprocess: image preprocessing function
289
+ split: 'val' or 'test'
290
+ root_dir: The directory of the flickr30k images. This should contain the `flickr30k-images` directory that \
291
+ contains all the images.
292
+ """
293
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
294
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
295
+ filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
296
+ if not os.path.exists(root_dir):
297
+ print("Directory for Flickr30k could not be found!")
298
+ flickr_url = "https://forms.illinois.edu/sec/229675"
299
+ raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")
300
+
301
+ download_url(urls[split],root_dir)
302
+
303
+ self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
304
+ self.image_preprocess = image_preprocess
305
+ self.root_dir = root_dir
306
+
307
+ self.test_cases = []
308
+
309
+ shuffler = TextShuffler()
310
+ perturb_functions = [shuffler.shuffle_nouns_and_adj, shuffler.shuffle_allbut_nouns_and_adj,
311
+ shuffler.shuffle_within_trigrams, shuffler.shuffle_trigrams]
312
+ for img_id, ann in tqdm(enumerate(self.annotation)):
313
+ for i, caption in enumerate(ann['caption']):
314
+ test_case = {}
315
+ test_case["image"] = ann["image"]
316
+ test_case["caption_options"] = [pre_caption(caption,max_words)]
317
+
318
+ for perturb_fn in perturb_functions:
319
+ test_case["caption_options"].append(pre_caption(perturb_fn(caption), max_words))
320
+ self.test_cases.append(test_case)
321
+
322
+ def __len__(self):
323
+ return len(self.test_cases)
324
+
325
+ def __getitem__(self, index):
326
+ test_case = self.test_cases[index]
327
+ image_path = os.path.join(self.root_dir, test_case["image"])
328
+ image = Image.open(image_path).convert('RGB')
329
+
330
+ if self.image_preprocess is not None:
331
+ image = self.image_preprocess(image)
332
+
333
+ item = edict({"image_options": [image], "caption_options": test_case["caption_options"]})
334
+ return item
335
+
336
+ def evaluate_scores(self, scores):
337
+ if isinstance(scores, tuple):
338
+ scores_i2t = scores[0]
339
+ scores_t2i = scores[1].T # Make it N_ims x N_text
340
+ else:
341
+ scores_t2i = scores
342
+ scores_i2t = scores
343
+
344
+ preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
345
+ correct_mask = (preds == 0)
346
+ result_records = [{"Precision@1": np.mean(correct_mask)}]
347
+ return result_records
348
+
349
+
350
+ def get_visual_genome_relation(image_preprocess, text_perturb_fn=None, image_perturb_fn=None, download=False):
351
+ return VG_Relation(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download)
352
+
353
+
354
+ def get_visual_genome_attribution(image_preprocess, text_perturb_fn=None, image_perturb_fn=None, download=False):
355
+ return VG_Attribution(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn,
356
+ image_perturb_fn=image_perturb_fn, download=download)
357
+
358
+ def get_coco_order(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=COCO_ROOT, split="test"):
359
+ return COCO_Order(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words,
360
+ download=download)
361
+
362
+ def get_flickr30k_order(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=FLICKR_ROOT, split="test"):
363
+ return Flickr30k_Order(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words,
364
+ download=download)
365
+
multimodal/build/lib/open_flamingo/eval/dataset_zoo/constants.py ADDED
@@ -0,0 +1,3 @@
1
+ ARO_ROOT = "~/.cache/prerelease_bow"
2
+ COCO_ROOT = "~/.cache/coco/2014"
3
+ FLICKR_ROOT = "~/.cache/flickr30k/images"
multimodal/build/lib/open_flamingo/eval/dataset_zoo/perturbations.py ADDED
@@ -0,0 +1,194 @@
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+ from functools import partial
5
+ import torch.nn.functional as nnf
6
+ from torchvision import transforms as T
7
+
8
+ # A lot of the approaches here are inspired from the wonderful paper from O'Connor and Andreas 2021.
9
+ # https://github.com/lingo-mit/context-ablations
10
+
11
+ def get_text_perturb_fn(text_perturb_fn):
12
+ if text_perturb_fn == "shuffle_nouns_and_adj":
13
+ return shuffle_nouns_and_adj
14
+ elif text_perturb_fn == "shuffle_allbut_nouns_and_adj":
15
+ return shuffle_allbut_nouns_and_adj
16
+ elif text_perturb_fn == "shuffle_within_trigrams":
17
+ return shuffle_within_trigrams
18
+ elif text_perturb_fn == "shuffle_all_words":
19
+ return shuffle_all_words
20
+ elif text_perturb_fn == "shuffle_trigrams":
21
+ return shuffle_trigrams
22
+ elif text_perturb_fn is None:
23
+ return None
24
+ else:
25
+ print("Unknown text perturbation function: {}, returning None".format(text_perturb_fn))
26
+ return None
27
+
28
+
29
+ def get_image_perturb_fn(image_perturb_fn):
30
+ if image_perturb_fn == "shuffle_rows_4":
31
+ return partial(shuffle_rows, n_rows=4)
32
+ elif image_perturb_fn == "shuffle_patches_9":
33
+ return partial(shuffle_patches, n_ratio=3)
34
+ elif image_perturb_fn == "shuffle_cols_4":
35
+ return partial(shuffle_columns, n_cols=4)
36
+ elif image_perturb_fn is None:
37
+ return None
38
+ else:
39
+ print("Unknown image perturbation function: {}, returning None".format(image_perturb_fn))
40
+ return None
41
+
42
+
43
+
44
+ class TextShuffler:
45
+
46
+ def __init__(self):
47
+ import spacy
48
+ self.nlp = spacy.load("en_core_web_sm")
49
+
50
+ def shuffle_nouns_and_adj(self, ex):
51
+
52
+ doc = self.nlp(ex)
53
+ tokens = [token.text for token in doc]
54
+ text = np.array(tokens)
55
+ noun_idx = [i for i, token in enumerate(doc) if token.tag_ in ['NN', 'NNS', 'NNP', 'NNPS']]
56
+ ## Finding adjectives
57
+ adjective_idx = [i for i, token in enumerate(doc) if token.tag_ in ['JJ', 'JJR', 'JJS']]
58
+ ## Shuffle the nouns of the text
59
+ text[noun_idx] = np.random.permutation(text[noun_idx])
60
+ ## Shuffle the adjectives of the text
61
+ text[adjective_idx] = np.random.permutation(text[adjective_idx])
62
+
63
+ return " ".join(text)
64
+
65
+ def shuffle_all_words(self, ex):
66
+ return " ".join(np.random.permutation(ex.split(" ")))
67
+
68
+
69
+ def shuffle_allbut_nouns_and_adj(self, ex):
70
+ doc = self.nlp(ex)
71
+ tokens = [token.text for token in doc]
72
+ text = np.array(tokens)
73
+ noun_adj_idx = [i for i, token in enumerate(doc) if token.tag_ in ['NN', 'NNS', 'NNP', 'NNPS', 'JJ', 'JJR', 'JJS']]
74
+ ## Finding adjectives
75
+
76
+ else_idx = np.ones(text.shape[0])
77
+ else_idx[noun_adj_idx] = 0
78
+
79
+ else_idx = else_idx.astype(bool)
80
+ ## Shuffle everything that are nouns or adjectives
81
+ text[else_idx] = np.random.permutation(text[else_idx])
82
+ return " ".join(text)
83
+
84
+
85
+ def get_trigrams(self, sentence):
86
+ # Taken from https://github.com/lingo-mit/context-ablations/blob/478fb18a9f9680321f0d37dc999ea444e9287cc0/code/transformers/src/transformers/data/data_augmentation.py
87
+ trigrams = []
88
+ trigram = []
89
+ for i in range(len(sentence)):
90
+ trigram.append(sentence[i])
91
+ if i % 3 == 2:
92
+ trigrams.append(trigram[:])
93
+ trigram = []
94
+ if trigram:
95
+ trigrams.append(trigram)
96
+ return trigrams
97
+
98
+ def trigram_shuffle(self, sentence):
99
+ trigrams = self.get_trigrams(sentence)
100
+ for trigram in trigrams:
101
+ random.shuffle(trigram)
102
+ return " ".join([" ".join(trigram) for trigram in trigrams])
103
+
104
+
105
+ def shuffle_within_trigrams(self, ex):
106
+ import nltk
107
+ tokens = nltk.word_tokenize(ex)
108
+ shuffled_ex = self.trigram_shuffle(tokens)
109
+ return shuffled_ex
110
+
111
+
112
+ def shuffle_trigrams(self, ex):
113
+ import nltk
114
+ tokens = nltk.word_tokenize(ex)
115
+ trigrams = self.get_trigrams(tokens)
116
+ random.shuffle(trigrams)
117
+ shuffled_ex = " ".join([" ".join(trigram) for trigram in trigrams])
118
+ return shuffled_ex
119
+
120
+
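A minimal usage sketch for the text perturbations above (an editorial example, not part of the commit): it assumes spaCy's en_core_web_sm model and NLTK's punkt data are installed, and it calls the TextShuffler methods directly, since the implementations shown here live on the class while get_text_perturb_fn refers to bare function names.

shuffler = TextShuffler()                          # loads spaCy's en_core_web_sm
caption = "A small brown dog chases a red ball"
print(shuffler.shuffle_nouns_and_adj(caption))     # nouns and adjectives permuted among themselves
print(shuffler.shuffle_all_words(caption))         # fully shuffled bag of words
print(shuffler.shuffle_within_trigrams(caption))   # uses nltk.word_tokenize, so punkt data is needed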
121
+ def _handle_image_4shuffle(x):
122
+ return_image = False
123
+ if not isinstance(x, torch.Tensor):
124
+ # print(f"x is not a tensor: {type(x)}. Trying to handle but fix this or I'll annoy you with this log")
125
+ t = torch.tensor(np.array(x)).unsqueeze(dim=0).float()
126
+ t = t.permute(0, 3, 1, 2)
127
+ return_image = True
128
+ return t, return_image
129
+ if len(x.shape) != 4:
130
+ #print("You did not send a tensor of shape NxCxWxH. Unsqueezing not but fix this or I'll annoy you with this log")
131
+ return x.unsqueeze(dim=0), return_image
132
+ else:
133
+ # Already a batched N x C x H x W tensor; return as-is.
134
+ return x, return_image
135
+
136
+
137
+ def shuffle_rows(x, n_rows=7):
138
+ """
139
+ Shuffle the rows of the image tensor, splitting the height into n_rows horizontal strips.
140
+ Tensor is of shape N x C x H x W
141
+ """
142
+ x, return_image = _handle_image_4shuffle(x)
143
+ patch_size = x.shape[-2]//n_rows
144
+ u = nnf.unfold(x, kernel_size=(patch_size, x.shape[-1]), stride=patch_size, padding=0)
145
+ # permute the patches of each image in the batch
146
+ pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
147
+ # fold the permuted patches back together
148
+ f = nnf.fold(pu, x.shape[-2:], kernel_size=(patch_size, x.shape[-1]), stride=patch_size, padding=0)
149
+
150
+ image = f.squeeze() # C W H
151
+ if return_image:
152
+ return T.ToPILImage()(image.type(torch.uint8))
153
+ else:
154
+ return image
155
+
156
+
157
+ def shuffle_columns(x, n_cols=7):
158
+ """
159
+ Shuffle the columns of the image tensor, splitting the width into n_cols vertical strips.
160
+ Tensor is of shape N x C x H x W
161
+ """
162
+ x, return_image = _handle_image_4shuffle(x)
163
+ patch_size = x.shape[-1]//n_cols
164
+ u = nnf.unfold(x, kernel_size=(x.shape[-2], patch_size), stride=patch_size, padding=0)
165
+ # permute the patches of each image in the batch
166
+ pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
167
+ # fold the permuted patches back together
168
+ f = nnf.fold(pu, x.shape[-2:], kernel_size=(x.shape[-2], patch_size), stride=patch_size, padding=0)
169
+ image = f.squeeze() # C W H
170
+ if return_image:
171
+ return T.ToPILImage()(image.type(torch.uint8))
172
+ else:
173
+ return image
174
+
175
+
176
+
177
+ def shuffle_patches(x, n_ratio=4):
178
+ """
179
+ Shuffle non-overlapping patches of the image tensor on an n_ratio x n_ratio grid.
180
+ Tensor is of shape N x C x H x W
181
+ """
182
+ x, return_image = _handle_image_4shuffle(x)
183
+ patch_size_x = x.shape[-2]//n_ratio
184
+ patch_size_y = x.shape[-1]//n_ratio
185
+ u = nnf.unfold(x, kernel_size=(patch_size_x, patch_size_y), stride=(patch_size_x, patch_size_y), padding=0)
186
+ # permute the patches of each image in the batch
187
+ pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
188
+ # fold the permuted patches back together
189
+ f = nnf.fold(pu, x.shape[-2:], kernel_size=(patch_size_x, patch_size_y), stride=(patch_size_x, patch_size_y), padding=0)
190
+ image = f.squeeze() # C W H
191
+ if return_image:
192
+ return T.ToPILImage()(image.type(torch.uint8))
193
+ else:
194
+ return image
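A hedged sketch of how the image perturbations above are meant to be applied; the import path and example.jpg are assumptions, not taken from the commit.

from PIL import Image
# assumed import path for this module
from open_flamingo.eval.dataset_zoo.perturbations import get_image_perturb_fn

perturb = get_image_perturb_fn("shuffle_patches_9")   # partial(shuffle_patches, n_ratio=3)
img = Image.open("example.jpg").convert("RGB")        # placeholder image path
scrambled = perturb(img)                              # PIL in -> PIL out via _handle_image_4shuffle
scrambled.save("example_shuffled.jpg")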
multimodal/build/lib/open_flamingo/eval/dataset_zoo/retrieval.py ADDED
@@ -0,0 +1,266 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import numpy as np
5
+
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+ from torch.utils.data import Dataset
9
+ from torchvision.datasets.utils import download_url
10
+
11
+ from .constants import COCO_ROOT, FLICKR_ROOT
12
+ from .utils import AverageMeter
13
+
14
+
15
+ def pre_caption(caption,max_words=50):
16
+ caption = re.sub(
17
+ r"([.!\"()*#:;~])",
18
+ ' ',
19
+ caption.lower(),
20
+ )
21
+ caption = re.sub(
22
+ r"\s{2,}",
23
+ ' ',
24
+ caption,
25
+ )
26
+ caption = caption.rstrip('\n')
27
+ caption = caption.strip(' ')
28
+
29
+ #truncate caption
30
+ caption_words = caption.split(' ')
31
+ if len(caption_words)>max_words:
32
+ caption = ' '.join(caption_words[:max_words])
33
+
34
+ return caption
35
+
36
+
37
+ class COCO_Retrieval(Dataset):
38
+ def __init__(self, image_preprocess=None, root_dir=COCO_ROOT, max_words=30, split="test",
39
+ image_perturb_fn=None, download=False):
40
+ """
41
+ COCO Retrieval Dataset.
42
+ image_preprocess: image preprocessing function
43
+ root_dir: The directory of the coco dataset. This directory should contain test2014 files.
44
+ max_words: Cropping the caption to max_words.
45
+ split: 'val' or 'test'
46
+ image_perturb_fn: image perturbation function for patch permutation experiments.
47
+ download: Whether to download the dataset if it does not exist.
48
+ """
49
+ self.root_dir = root_dir
50
+ if not os.path.exists(root_dir):
51
+ print("Directory for COCO could not be found!")
52
+ if download:
53
+ print("Downloading COCO now.")
54
+ self.download()
55
+ else:
56
+ raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
57
+
58
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
59
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
60
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
61
+ download_url(urls[split],root_dir)
62
+
63
+
64
+ self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
65
+ self.image_preprocess = image_preprocess
66
+ self.image_perturb_fn = image_perturb_fn
67
+ self.image_root = root_dir
68
+
69
+ self.text = []
70
+ self.image = []
71
+ self.txt2img = {}
72
+ self.img2txt = {}
73
+
74
+ txt_id = 0
75
+ for img_id, ann in enumerate(self.annotation):
76
+ self.image.append(ann['image'])
77
+ self.img2txt[img_id] = []
78
+ for i, caption in enumerate(ann['caption']):
79
+ self.text.append(pre_caption(caption,max_words))
80
+ self.img2txt[img_id].append(txt_id)
81
+ self.txt2img[txt_id] = img_id
82
+ txt_id += 1
83
+
84
+ def __len__(self):
85
+ return len(self.annotation)
86
+
87
+ def __getitem__(self, index):
88
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
89
+ image = Image.open(image_path).convert('RGB')
90
+
91
+ if self.image_preprocess is not None:
92
+ image = self.image_preprocess(image)
93
+
94
+ if self.image_perturb_fn is not None:
95
+ image = self.image_perturb_fn(image)
96
+
97
+ return {"image": image, "idx": index}
98
+
99
+ def download(self):
100
+ import subprocess
101
+ os.makedirs(self.root_dir, exist_ok=True)
102
+ #subprocess.call(["wget", "http://images.cocodataset.org/zips/train2014.zip"], cwd=self.root_dir)
103
+ #subprocess.call(["unzip", "train2014.zip"], cwd=self.root_dir)
104
+
105
+ subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=self.root_dir)
106
+ subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)
107
+
108
+ subprocess.call(["wget", "http://images.cocodataset.org/zips/test2014.zip"], cwd=self.root_dir)
109
+ subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)
110
+
111
+
112
+ def evaluate_scores(self, scores):
113
+ if isinstance(scores, tuple):
114
+ scores_i2t = scores[0]
115
+ scores_t2i = scores[1].T # Make it N_ims x N_text
116
+
117
+ else:
118
+ scores_t2i = scores
119
+ scores_i2t = scores
120
+
121
+ print(f"COCO results across {scores_i2t.shape} samples. ")
122
+ prec_at_1 = AverageMeter()
123
+ prec_at_5 = AverageMeter()
124
+
125
+ # Text retrieval
126
+ tqdm_iterator = tqdm(range(len(self.img2txt)))
127
+ for i in tqdm_iterator:
128
+ top5_captions = np.argsort(scores_i2t[i])[-5:]
129
+ true_captions = self.img2txt[i]
130
+
131
+ prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:]))>0)
132
+ prec_at_5.update(len(set(true_captions) & set(top5_captions))>0)
133
+
134
+ tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")
135
+
136
+ # Image Retrieval
137
+ image_prec_at_1 = AverageMeter()
138
+ image_prec_at_5 = AverageMeter()
139
+
140
+ tqdm_iterator = tqdm(range(len(self.txt2img)))
141
+ for i in tqdm_iterator:
142
+ top5_images = np.argsort(scores_t2i[:, i])[-5:]
143
+ true_image = self.txt2img[i]
144
+
145
+ image_prec_at_1.update(true_image in top5_images[-1:])
146
+ image_prec_at_5.update(true_image in top5_images)
147
+
148
+ tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")
149
+
150
+ records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
151
+ return records
152
+
153
+
154
+
155
+ class Flickr30k_Retrieval(Dataset):
156
+ def __init__(self, image_preprocess, split, root_dir=FLICKR_ROOT, max_words=30,
157
+ image_perturb_fn=None, *args, **kwargs):
158
+ '''
159
+ Flickr30k dataset for retrieval.
160
+ image_preprocess: image preprocessing function
161
+ root_dir: The directory of the Flickr30k dataset; it should contain the images referenced by the annotation files.
162
+ max_words: Cropping the caption to max_words.
163
+ split: 'val' or 'test'
164
+ image_perturb_fn: image perturbation function for patch permutation experiments.
165
+ download: Not supported here; the Flickr30k images must be downloaded manually (see below).
166
+ '''
167
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
168
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
169
+ filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
170
+
171
+ if not os.path.exists(root_dir):
172
+ print("Directory for Flickr30k could not be found!")
173
+ flickr_url = "https://forms.illinois.edu/sec/229675"
174
+ raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")
175
+
176
+ download_url(urls[split],root_dir)
177
+
178
+ self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
179
+ self.image_preprocess = image_preprocess
180
+ self.image_perturb_fn = image_perturb_fn
181
+ self.root_dir = root_dir
182
+
183
+ self.text = []
184
+ self.image = []
185
+ self.txt2img = {}
186
+ self.img2txt = {}
187
+
188
+ txt_id = 0
189
+ for img_id, ann in enumerate(self.annotation):
190
+ self.image.append(ann['image'])
191
+ self.img2txt[img_id] = []
192
+ for i, caption in enumerate(ann['caption']):
193
+ self.text.append(pre_caption(caption,max_words))
194
+ self.img2txt[img_id].append(txt_id)
195
+ self.txt2img[txt_id] = img_id
196
+ txt_id += 1
197
+
198
+ def __len__(self):
199
+ return len(self.annotation)
200
+
201
+ def __getitem__(self, index):
202
+ image_path = os.path.join(self.root_dir, self.annotation[index]['image'])
203
+ image = Image.open(image_path).convert('RGB')
204
+ if self.image_preprocess is not None:
205
+ image = self.image_preprocess(image)
206
+ if self.image_perturb_fn is not None:
207
+ image = self.image_perturb_fn(image)
208
+
209
+ return {"image": image, "idx": index}
210
+
211
+ def evaluate_scores(self, scores):
212
+ if isinstance(scores, tuple):
213
+ scores_i2t = scores[0]
214
+ scores_t2i = scores[1].T # Make it N_ims x N_text
215
+
216
+ else:
217
+ scores_t2i = scores
218
+ scores_i2t = scores
219
+
220
+ print(f"Flickr30k Retrieval results across {scores_i2t.shape} samples. ")
221
+ prec_at_1 = AverageMeter()
222
+ prec_at_5 = AverageMeter()
223
+
224
+ # Text retrieval
225
+ tqdm_iterator = tqdm(range(len(self.img2txt)))
226
+ for i in tqdm_iterator:
227
+ top5_captions = np.argsort(scores_i2t[i])[-5:]
228
+ true_captions = self.img2txt[i]
229
+
230
+ prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:]))>0)
231
+ prec_at_5.update(len(set(true_captions) & set(top5_captions))>0)
232
+
233
+ tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")
234
+
235
+ # Image Retrieval
236
+ image_prec_at_1 = AverageMeter()
237
+ image_prec_at_5 = AverageMeter()
238
+
239
+ tqdm_iterator = tqdm(range(len(self.txt2img)))
240
+ for i in tqdm_iterator:
241
+ top5_images = np.argsort(scores_t2i[:, i])[-5:]
242
+ true_image = self.txt2img[i]
243
+
244
+ image_prec_at_1.update(true_image in top5_images[-1:])
245
+ image_prec_at_5.update(true_image in top5_images)
246
+
247
+ tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")
248
+
249
+ records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
250
+ return records
251
+
252
+ def download(self):
253
+ raise NotImplementedError("Flickr30k dataset is not available for download.")
254
+
255
+
256
+
257
+ def get_coco_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=COCO_ROOT, split="test"):
258
+ dataset = COCO_Retrieval(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words,
259
+ download=download)
260
+ return dataset
261
+
262
+
263
+ def get_flickr30k_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=FLICKR_ROOT, split="test"):
264
+ dataset = Flickr30k_Retrieval(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words,
265
+ download=download)
266
+ return dataset
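A hedged sketch of the intended retrieval flow (the COCO root is a placeholder that must already contain the images, and the random score matrix only exercises the metric bookkeeping; it is not a meaningful result):

import numpy as np

dataset = get_coco_retrieval(image_preprocess=None, image_perturb_fn=None, text_perturb_fn=None,
                             root_dir="/path/to/coco", split="test")
# scores_i2t[i, j]: model similarity between image i and caption j
scores_i2t = np.random.rand(len(dataset), len(dataset.text))
records = dataset.evaluate_scores(scores_i2t)
print(records[0])   # {'ImagePrec@1': ..., 'ImagePrec@5': ..., 'TextPrec@1': ..., 'TextPrec@5': ...}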
multimodal/build/lib/open_flamingo/eval/dataset_zoo/utils.py ADDED
@@ -0,0 +1,15 @@
1
+ class AverageMeter(object):
2
+ def __init__(self):
3
+ self.reset()
4
+
5
+ def reset(self):
6
+ self.val = 0
7
+ self.avg = 0
8
+ self.sum = 0
9
+ self.count = 0
10
+
11
+ def update(self, val, n=1):
12
+ self.val = val
13
+ self.sum += val * n
14
+ self.count += n
15
+ self.avg = self.sum / self.count
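A small illustration of how AverageMeter accumulates the Prec@1/Prec@5 numbers in the retrieval evaluators above (the hits are made up):

meter = AverageMeter()
for hit in [True, False, True, True]:   # booleans count as 1/0
    meter.update(hit)
print(meter.avg)                        # 0.75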
multimodal/build/lib/open_flamingo/eval/eval_datasets.py ADDED
@@ -0,0 +1,101 @@
1
+ import json
2
+ import os
3
+
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+ from torchvision.datasets import ImageFolder
7
+
8
+ from open_flamingo.eval.imagenet_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
9
+
10
+
11
+ class COCOFlickrDataset(Dataset):
12
+ def __init__(
13
+ self,
14
+ image_dir_path,
15
+ annotations_path,
16
+ is_flickr=False,
17
+ ):
18
+ self.image_dir_path = image_dir_path
19
+ self.annotations = json.load(open(annotations_path))["annotations"]
20
+ self.is_flickr = is_flickr
21
+
22
+ def __len__(self):
23
+ return len(self.annotations)
24
+
25
+ def get_img_path(self, idx):
26
+ if self.is_flickr:
27
+ return f"{self.image_dir_path}/{self.annotations[idx]['image_id']}.jpg"
28
+ else:
29
+ return f"{self.image_dir_path}/{self.annotations[idx]['image_id']:012d}.jpg"
30
+
31
+ def __getitem__(self, idx):
32
+ image = Image.open(self.get_img_path(idx))
33
+ caption = self.annotations[idx]["caption"]
34
+ return {
35
+ "image": image,
36
+ "caption": caption,
37
+ "image_id": self.annotations[idx]["image_id"],
38
+ }
39
+
40
+
41
+ class VQADataset(Dataset):
42
+ def __init__(
43
+ self,
44
+ image_dir_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/train2014/",
45
+ question_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_OpenEnded_mscoco_train2014_questions.json",
46
+ annotations_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_mscoco_train2014_annotations.json",
47
+ vqa_dataset="vqa",
48
+ ):
49
+ self.questions = json.load(open(question_path, "r"))["questions"]
50
+ self.answers = json.load(open(annotations_path, "r"))["annotations"]
51
+ self.image_dir_path = image_dir_path
52
+ self.vqa_dataset = vqa_dataset
53
+
54
+ def __len__(self):
55
+ return len(self.questions)
56
+
57
+ def get_img_path(self, question):
58
+ if self.vqa_dataset == "vqa":
59
+ return os.path.join(
60
+ self.image_dir_path, f"COCO_val2014_{question['image_id']:012d}.jpg"
61
+ )
62
+ elif self.vqa_dataset == "ok_vqa":
63
+ return os.path.join(
64
+ self.image_dir_path, f"COCO_val2014_{question['image_id']:012d}.jpg"
65
+ )
66
+ else:
67
+ raise Exception(f"Unknown VQA dataset {self.vqa_dataset}")
68
+
69
+ def __getitem__(self, idx):
70
+ question = self.questions[idx]
71
+ answers = self.answers[idx]
72
+ img_path = self.get_img_path(question)
73
+ image = Image.open(img_path)
74
+ return {
75
+ "image": image,
76
+ "question": question["question"],
77
+ "answers": [a["answer"] for a in answers["answers"]],
78
+ "question_id": question["question_id"],
79
+ }
80
+
81
+
82
+ class ImageNetDataset(ImageFolder):
83
+ """Class to represent the ImageNet1k dataset."""
84
+
85
+ def __init__(self, root, **kwargs):
86
+ super().__init__(root=root, **kwargs)
87
+
88
+ def __getitem__(self, idx):
89
+ sample, target = super().__getitem__(idx)
90
+ target_label = IMAGENET_1K_CLASS_ID_TO_LABEL[target]
91
+ return {
92
+ "image": sample,
93
+ "class_id": target, # numeric ID of the ImageNet class
94
+ "class_name": target_label, # human-readable name of ImageNet class
95
+ }
96
+
97
+
98
+ if __name__ == "__main__":
99
+ from open_flamingo.eval.task.gqa import GQADataset  # GQADataset is defined in the task module
+ gqa_dataset = GQADataset()
100
+ for sample in gqa_dataset:
101
+ print(sample)
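A quick, assumption-laden sketch of the captioning dataset above; the paths are placeholders, and note that get_img_path expects images named by their zero-padded ids:

ds = COCOFlickrDataset(
    image_dir_path="/path/to/coco/val2014",
    annotations_path="/path/to/coco/annotations/captions_val2014.json",
)
print(len(ds))
print(ds.get_img_path(0))   # /path/to/coco/val2014/000000xxxxxx.jpg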
multimodal/build/lib/open_flamingo/eval/evaluate.py ADDED
@@ -0,0 +1,1435 @@
1
+ import argparse
2
+ import json
3
+ from math import ceil
4
+ import logging
+ import os
5
+ import random
6
+ import uuid
7
+ from collections import defaultdict
8
+ from typing import Callable
9
+ import time
10
+ import cv2
11
+ import webdataset as wds
12
+ from sklearn.metrics import recall_score, average_precision_score
13
+
14
+ import more_itertools
15
+ import numpy as np
16
+ import torch
17
+ from coco_metric import compute_cider, postprocess_captioning_generation
18
+ from eval_datasets import VQADataset
19
+ from tqdm import tqdm
20
+ from collections import Counter
21
+
22
+ from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
23
+ from open_flamingo.eval.classification import (
24
+ compute_per_sample_probs,
25
+ compute_per_sample_loss,
26
+ )
27
+ from open_flamingo.eval.imagenet_utils import (
28
+ openai_imagenet_classnames,
29
+ IMAGENET_1K_CLASS_ID_TO_LABEL,
30
+ )
31
+
32
+ from open_flamingo.src.factory import create_model_and_transforms
33
+ from PIL import Image
34
+ from io import BytesIO
35
+ import base64
36
+ from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
37
+ import string
38
+ from open_flamingo.eval.task.reg import evaluate_reg
39
+ from open_flamingo.eval.task.gqa import GQADataset
40
+ from open_flamingo.eval.task.vl_checklist import evaluate_vlc
41
+ from open_flamingo.eval.task.crepe import evaluate_crepe
42
+ from open_flamingo.eval.task.caption import evaluate_coco_flickr
43
+ from open_flamingo.eval.task.utils import is_correct, get_iou
44
+ from open_flamingo.eval.task.cola import evaluate_cola
45
+ from open_flamingo.eval.task.gqa import evaluate_gqa
46
+
47
+ def expand2square(pil_img, background_color):
48
+ width, height = pil_img.size
49
+ if width == height:
50
+ return pil_img
51
+ elif width > height:
52
+ result = Image.new(pil_img.mode, (width, width), background_color)
53
+ result.paste(pil_img, (0, (width - height) // 2))
54
+ return result
55
+ else:
56
+ result = Image.new(pil_img.mode, (height, height), background_color)
57
+ result.paste(pil_img, ((height - width) // 2, 0))
58
+ return result
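+ # e.g. expand2square(img, (255, 255, 255)) above pastes the image centered on a white square
+ # canvas, so it can later be resized to a square model input without distorting the aspect ratio.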
59
+
60
+ parser = argparse.ArgumentParser()
61
+ parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
62
+ parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
63
+ parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
64
+ parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
65
+ parser.add_argument("--checkpoint_path", type=str, required=True)
66
+ parser.add_argument(
67
+ "--results_file", type=str, default=None, help="JSON file to save results"
68
+ )
69
+
70
+ # Trial arguments
71
+ parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
72
+ parser.add_argument(
73
+ "--num_trials",
74
+ type=int,
75
+ default=1,
76
+ help="Number of trials to run for each shot using different demonstrations",
77
+ )
78
+ parser.add_argument(
79
+ "--trial_seeds",
80
+ nargs="+",
81
+ default=[0],
82
+ help="Seeds to use for each trial for picking demonstrations and eval sets",
83
+ )
84
+ parser.add_argument(
85
+ "--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
86
+ )
87
+
88
+ parser.add_argument("--batch_size", type=int, default=8)
89
+
90
+ # Per-dataset evaluation flags
91
+ parser.add_argument(
92
+ "--eval_coco",
93
+ action="store_true",
94
+ default=False,
95
+ help="Whether to evaluate on COCO.",
96
+ )
97
+ parser.add_argument(
98
+ "--eval_vqav2",
99
+ action="store_true",
100
+ default=False,
101
+ help="Whether to evaluate on VQAV2.",
102
+ )
103
+ parser.add_argument(
104
+ "--eval_ok_vqa",
105
+ action="store_true",
106
+ default=False,
107
+ help="Whether to evaluate on OK-VQA.",
108
+ )
109
+ parser.add_argument(
110
+ "--eval_imagenet",
111
+ action="store_true",
112
+ default=False,
113
+ help="Whether to evaluate on ImageNet.",
114
+ )
115
+
116
+ parser.add_argument(
117
+ "--eval_flickr30",
118
+ action="store_true",
119
+ default=False,
120
+ help="Whether to evaluate on Flickr30.",
121
+ )
122
+
123
+ parser.add_argument(
124
+ "--eval_refcoco",
125
+ action="store_true",
126
+ default=False,
127
+ help="Whether to evaluate on RefCOCO.",
128
+ )
129
+
130
+ # Dataset arguments
131
+
132
+ ## Flickr30 Dataset
133
+ parser.add_argument(
134
+ "--flickr_image_dir_path",
135
+ type=str,
136
+ help="Path to the flickr30/flickr30k_images directory.",
137
+ default=None,
138
+ )
139
+ parser.add_argument(
140
+ "--flickr_annotations_json_path",
141
+ type=str,
142
+ help="Path to the dataset_flickr30k_coco_style.json file.",
143
+ default=None,
144
+ )
145
+
146
+ ## COCO Dataset
147
+ parser.add_argument(
148
+ "--coco_image_dir_path",
149
+ type=str,
150
+ help="Path to the flickr30/flickr30k_images directory.",
151
+ default=None,
152
+ )
153
+ parser.add_argument(
154
+ "--coco_annotations_json_path",
155
+ type=str,
156
+ default=None,
157
+ )
158
+
159
+ ## VQAV2 Dataset
160
+ parser.add_argument(
161
+ "--vqav2_image_dir_path",
162
+ type=str,
163
+ default=None,
164
+ )
165
+ parser.add_argument(
166
+ "--vqav2_questions_json_path",
167
+ type=str,
168
+ default=None,
169
+ )
170
+ parser.add_argument(
171
+ "--vqav2_annotations_json_path",
172
+ type=str,
173
+ default=None,
174
+ )
175
+
176
+ ## OK-VQA Dataset
177
+ parser.add_argument(
178
+ "--ok_vqa_image_dir_path",
179
+ type=str,
180
+ help="Path to the vqav2/train2014 directory.",
181
+ default=None,
182
+ )
183
+ parser.add_argument(
184
+ "--ok_vqa_questions_json_path",
185
+ type=str,
186
+ help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
187
+ default=None,
188
+ )
189
+ parser.add_argument(
190
+ "--ok_vqa_annotations_json_path",
191
+ type=str,
192
+ help="Path to the v2_mscoco_train2014_annotations.json file.",
193
+ default=None,
194
+ )
195
+
196
+ ## Imagenet dataset
197
+ parser.add_argument("--imagenet_root", type=str, default="/tmp")
198
+
199
+ ## RefCOCO dataset
200
+ parser.add_argument("--refcoco_tsvfile", type=str, default=None)
201
+
202
+ parser.add_argument(
203
+ "--location_token_num",
204
+ default=1000,
205
+ type=int,
206
+ )
207
+ # distributed training
208
+ parser.add_argument(
209
+ "--dist-url",
210
+ default="env://",
211
+ type=str,
212
+ help="url used to set up distributed training",
213
+ )
214
+ parser.add_argument(
215
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
216
+ )
217
+ parser.add_argument(
218
+ "--horovod",
219
+ default=False,
220
+ action="store_true",
221
+ help="Use horovod for distributed training.",
222
+ )
223
+ parser.add_argument(
224
+ "--no-set-device-rank",
225
+ default=False,
226
+ action="store_true",
227
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
228
+ )
229
+ parser.add_argument(
230
+ "--dist",
231
+ default=False,
232
+ action="store_true",
233
+ )
234
+ parser.add_argument(
235
+ "--lora",
236
+ default=False,
237
+ action="store_true",
238
+ )
239
+ parser.add_argument(
240
+ "--lora_r",
241
+ default=16,
242
+ type=int,
243
+ required=False,
244
+ )
245
+ parser.add_argument(
246
+ "--legacy",
247
+ default=False,
248
+ action="store_true",
249
+ )
250
+ parser.add_argument(
251
+ "--special",
252
+ default=False,
253
+ action="store_true",
254
+ )
255
+ parser.add_argument(
256
+ "--id",
257
+ default=0,
258
+ type=int,
259
+ required=False,
260
+ )
261
+
262
+ parser.add_argument(
263
+ "--eval_gqa",
264
+ default=False,
265
+ action="store_true",
266
+ )
267
+ parser.add_argument(
268
+ "--use_sam",
269
+ default=None,
270
+ type=str,
271
+ required=False,
272
+ )
273
+ parser.add_argument(
274
+ "--add_visual_token",
275
+ default=False,
276
+ action="store_true",
277
+ )
278
+ parser.add_argument(
279
+ "--use_format_v2",
280
+ default=False,
281
+ action="store_true",
282
+ )
283
+ parser.add_argument(
284
+ "--eval_aro",
285
+ default=False,
286
+ action="store_true",
287
+ )
288
+ parser.add_argument(
289
+ "--eval_pisc",
290
+ default=False,
291
+ action="store_true",
292
+ )
293
+ parser.add_argument(
294
+ "--eval_reg",
295
+ default=False,
296
+ action="store_true",
297
+ )
298
+ parser.add_argument(
299
+ "--eval_vlc",
300
+ default=False,
301
+ action="store_true",
302
+ )
303
+ parser.add_argument(
304
+ "--eval_crepe",
305
+ default=False,
306
+ action="store_true",
307
+ )
308
+ parser.add_argument(
309
+ "--eval_cola",
310
+ default=False,
311
+ action="store_true",
312
+ )
313
+ parser.add_argument(
314
+ "--level",
315
+ default=4,
316
+ type=int,
317
+ )
318
+ parser.add_argument(
319
+ "--type",
320
+ default="swap",
321
+ type=str,
322
+ )
323
+ parser.add_argument(
324
+ "--choose_left_right",
325
+ default=False,
326
+ action="store_true",
327
+ )
328
+
329
+
330
+ class OKVQAPostProcess():
331
+ def __init__(self):
332
+ self._lemmatizer = None
333
+
334
+ def _lemmatize(self, answers):
335
+ def apply(answer):
336
+ doc = self.lemmatizer(answer)
337
+
338
+ words = []
339
+ for token in doc:
340
+ if token.pos_ in ["NOUN", "VERB"]:
341
+ words.append(token.lemma_)
342
+ else:
343
+ words.append(token.text)
344
+ answer = " ".join(words)
345
+
346
+ return answer
347
+
348
+ return [apply(answer) for answer in answers]
349
+
350
+ @property
351
+ def lemmatizer(self):
352
+ if self._lemmatizer is None:
353
+ try:
354
+ import spacy
355
+
356
+ self._lemmatizer = spacy.load("en_core_web_sm")
357
+ except ImportError:
358
+ logging.error(
359
+ """
360
+ Please install spacy and en_core_web_sm model to apply lemmatization.
361
+ python -m spacy download en_core_web_sm
362
+ OR
363
+ import spacy.cli
364
+ spacy.cli.download("en_core_web_sm")
365
+ """
366
+ )
367
+ exit(1)
368
+
369
+ return self._lemmatizer
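+ # Used in evaluate_vqa below to lemmatize OK-VQA predictions (nouns and verbs only),
+ # so that e.g. "running" can match the annotated answer "run".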
370
+
371
+
372
+ def main():
373
+ args = parser.parse_args()
374
+ if args.dist:
375
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
376
+ print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
377
+ device_id = init_distributed_device(args)
378
+ else:
379
+ args.rank = 0
380
+ args.world_size = 1
381
+ print(f"rank: {args.rank} world_size: {args.world_size}")
382
+
383
+ if "sam" in args.checkpoint_path:
384
+ args.use_sam = "vit_l"
385
+
386
+ args.add_visual_token = True
387
+ if "lora" in args.checkpoint_path:
388
+ args.lora = True
389
+
390
+
391
+ args.add_pe = False
392
+ args.add_box = True
393
+ args.relation = False
394
+ args.enhance_data = False
395
+ args.use_format_v2 = True
396
+
397
+
398
+
399
+ import hashlib
400
+ args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest()
401
+
402
+ # load model
403
+ flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
404
+ args.vision_encoder_path,
405
+ args.vision_encoder_pretrained,
406
+ args.lm_path,
407
+ args.lm_tokenizer_path,
408
+ location_token_num=args.location_token_num,
409
+ lora=args.lora,
410
+ lora_r=16,
411
+ use_sam=args.use_sam,
412
+ add_visual_token=args.add_visual_token,
413
+ use_format_v2=args.use_format_v2,
414
+ add_box=args.add_box,
415
+ add_pe=args.add_pe,
416
+ add_relation=args.relation,
417
+ enhance_data=args.enhance_data,
418
+ )
419
+ flamingo.use_format_v2 = args.use_format_v2
420
+ if args.special:
421
+ flamingo.special = True
422
+ else:
423
+ flamingo.special = False
424
+ if args.legacy:
425
+ flamingo.legacy = True
426
+ print("use legacy evaluation")
427
+ flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
428
+ flamingo.expr_name = args.checkpoint_path.split("/")[-2]
429
+ if args.rank == 0:
430
+ print("legacy", True if hasattr(flamingo, "legacy") else False)
431
+ print("step:", flamingo.step_num)
432
+ print("expr:", flamingo.expr_name)
433
+ print("use format v2:", flamingo.use_format_v2)
434
+ print(args)
435
+ checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
436
+ model_state_dict = {}
437
+ for key in checkpoint["model_state_dict"].keys():
438
+ model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
439
+ if "vision_encoder.logit_scale"in model_state_dict:
440
+ # previous checkpoint has some unnecessary weights
441
+ del model_state_dict["vision_encoder.logit_scale"]
442
+ del model_state_dict["vision_encoder.visual.proj"]
443
+ del model_state_dict["vision_encoder.visual.ln_post.weight"]
444
+ del model_state_dict["vision_encoder.visual.ln_post.bias"]
445
+ flamingo.load_state_dict(model_state_dict, strict=True)
446
+ results = defaultdict(list)
447
+ if args.eval_coco:
448
+ print("Evaluating on COCO...")
449
+ cider_score = evaluate_coco_flickr(
450
+ model=flamingo,
451
+ tokenizer=tokenizer,
452
+ image_processor=image_processor,
453
+ batch_size=args.batch_size,
454
+ vis_embed_size=vis_embed_size,
455
+ rank=args.rank,
456
+ world_size=args.world_size,
457
+ id=args.id,
458
+ )
459
+ results["coco"].append({"score": cider_score})
460
+
461
+ if args.eval_ok_vqa:
462
+ print("Evaluating on OK-VQA...")
463
+ for shot in args.shots:
464
+ scores = []
465
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
466
+ ok_vqa_score = evaluate_vqa(
467
+ model=flamingo,
468
+ tokenizer=tokenizer,
469
+ image_processor=image_processor,
470
+ batch_size=args.batch_size,
471
+ image_dir_path=args.ok_vqa_image_dir_path,
472
+ questions_json_path=args.ok_vqa_questions_json_path,
473
+ annotations_json_path=args.ok_vqa_annotations_json_path,
474
+ vqa_dataset="ok_vqa",
475
+ vis_embed_size=vis_embed_size,
476
+ rank=args.rank,
477
+ world_size=args.world_size,
478
+ id=args.id,
479
+ )
480
+ results["ok_vqa"].append(
481
+ {"shots": shot, "score": ok_vqa_score}
482
+ )
483
+
484
+ if args.eval_vqav2:
485
+ print("Evaluating on VQAv2...")
486
+ for shot in args.shots:
487
+ scores = []
488
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
489
+ vqa_score = evaluate_vqa(
490
+ model=flamingo,
491
+ tokenizer=tokenizer,
492
+ image_processor=image_processor,
493
+ batch_size=args.batch_size,
494
+ image_dir_path=args.vqav2_image_dir_path,
495
+ questions_json_path=args.vqav2_questions_json_path,
496
+ annotations_json_path=args.vqav2_annotations_json_path,
497
+ vqa_dataset="vqa",
498
+ vis_embed_size=vis_embed_size,
499
+ rank=args.rank,
500
+ world_size=args.world_size,
501
+ id=args.id,
502
+ )
503
+ results["vqav2"].append(
504
+ {"shots": shot, "score": vqa_score}
505
+ )
506
+
507
+ if args.eval_gqa:
508
+ print("Evaluating on GQA...")
509
+ gqa_score = evaluate_gqa(
510
+ model=flamingo,
511
+ tokenizer=tokenizer,
512
+ image_processor=image_processor,
513
+ batch_size=args.batch_size,
514
+ vis_embed_size=vis_embed_size,
515
+ rank=args.rank,
516
+ world_size=args.world_size,
517
+ id=args.id,
518
+ )
519
+ results["gqa"].append(
520
+ {"score": gqa_score}
521
+ )
522
+
523
+ if args.eval_refcoco:
524
+ print("Evaluating on RefCOCO...")
525
+ refcoco_score = evaluate_refcoco(
526
+ model=flamingo,
527
+ tokenizer=tokenizer,
528
+ image_processor=image_processor,
529
+ batch_size=args.batch_size,
530
+ device=args.device,
531
+ tsvfile=args.refcoco_tsvfile,
532
+ vis_embed_size=vis_embed_size,
533
+ rank=args.rank,
534
+ world_size=args.world_size,
535
+ id=args.id,
536
+ )
537
+ results["refcoco"].append(
538
+ {"score": refcoco_score}
539
+ )
540
+ if args.eval_aro:
541
+ print("Evaluating on ARO...")
542
+ aro_score = evaluate_aro(
543
+ model=flamingo,
544
+ tokenizer=tokenizer,
545
+ image_processor=image_processor,
546
+ vis_embed_size=vis_embed_size,
547
+ rank=args.rank,
548
+ world_size=args.world_size,
549
+ id=args.id,
550
+ choose_left_right=args.choose_left_right,
551
+ )
552
+ results["aro"].append(
553
+ {"score": aro_score}
554
+ )
555
+ if args.eval_pisc:
556
+ print("Evaluating on ARO...")
557
+ aro_score = evaluate_pisc(
558
+ model=flamingo,
559
+ tokenizer=tokenizer,
560
+ image_processor=image_processor,
561
+ batch_size=args.batch_size,
562
+ device=args.device,
563
+ tsvfile=args.refcoco_tsvfile,
564
+ vis_embed_size=vis_embed_size,
565
+ rank=args.rank,
566
+ world_size=args.world_size,
567
+ id=args.id,
568
+ )
569
+ results["pisc"].append(
570
+ {"score": aro_score}
571
+ )
572
+ if args.eval_reg:
573
+ print("Evaluating on Referring Expression Generation...")
574
+ cider = evaluate_reg(
575
+ model=flamingo,
576
+ tokenizer=tokenizer,
577
+ image_processor=image_processor,
578
+ vis_embed_size=vis_embed_size,
579
+ rank=args.rank,
580
+ world_size=args.world_size,
581
+ id=args.id,
582
+ )
583
+ results["reg"].append(
584
+ {"score": cider}
585
+ )
586
+ if args.eval_vlc:
587
+ print("Evaluating on VL-checklist...")
588
+ vlc_score = evaluate_vlc(
589
+ model=flamingo,
590
+ tokenizer=tokenizer,
591
+ image_processor=image_processor,
592
+ vis_embed_size=vis_embed_size,
593
+ rank=args.rank,
594
+ world_size=args.world_size,
595
+ id=args.id,
596
+ )
597
+ results["vlc"].append(
598
+ {"score": vlc_score}
599
+ )
600
+ if args.eval_crepe:
601
+ print("Evaluating on CREPE...")
602
+ crepe_score = evaluate_crepe(
603
+ model=flamingo,
604
+ tokenizer=tokenizer,
605
+ image_processor=image_processor,
606
+ vis_embed_size=vis_embed_size,
607
+ rank=args.rank,
608
+ world_size=args.world_size,
609
+ id=args.id,
610
+ level=args.level,
611
+ type=args.type,
612
+ )
613
+ results["crepe"].append(
614
+ {"score": crepe_score}
615
+ )
616
+ if args.eval_cola:
617
+ print("Evaluating on COLA...")
618
+ cola_score = evaluate_cola(
619
+ model=flamingo,
620
+ tokenizer=tokenizer,
621
+ image_processor=image_processor,
622
+ vis_embed_size=vis_embed_size,
623
+ rank=args.rank,
624
+ world_size=args.world_size,
625
+ id=args.id,
626
+ )
627
+ results["cola"].append(
628
+ {"score": cola_score}
629
+ )
630
+
631
+ def prepare_batch_images(batch, image_processor):
632
+ batch_images = None
633
+ for b in batch:
634
+ b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
635
+ if batch_images is None:
636
+ batch_images = b_image
637
+ else:
638
+ batch_images = torch.cat([batch_images, b_image], dim=0)
639
+ return batch_images
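+ # Shape note: each image becomes a (1, 1, 1, C, H, W) tensor and samples are concatenated
+ # along dim 0, i.e. one single-frame image per sample in the model's vision_x layout.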
640
+
641
+ def get_outputs(
642
+ model,
643
+ batch_images,
644
+ attention_mask,
645
+ max_generation_length,
646
+ min_generation_length,
647
+ num_beams,
648
+ length_penalty,
649
+ input_ids,
650
+ image_start_index_list=None,
651
+ image_nums=None,
652
+ bad_words_ids=None,
653
+ ):
654
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):  # enter both contexts
655
+ outputs = model.generate(
656
+ batch_images,
657
+ input_ids,
658
+ attention_mask=attention_mask,
659
+ max_new_tokens=max_generation_length,
660
+ min_length=min_generation_length,
661
+ num_beams=num_beams,
662
+ length_penalty=length_penalty,
663
+ image_start_index_list=image_start_index_list,
664
+ image_nums=image_nums,
665
+ bad_words_ids=bad_words_ids,
666
+ )
667
+
668
+ outputs = outputs[:, len(input_ids[0]) :]
669
+ return outputs
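+ # Only the newly generated tokens are returned; the prompt (the first len(input_ids[0])
+ # tokens) is sliced off above.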
670
+
671
+
672
+ def evaluate_vqa(
673
+ model,
674
+ tokenizer,
675
+ image_processor,
676
+ batch_size,
677
+ image_dir_path=None,
678
+ questions_json_path=None,
679
+ annotations_json_path=None,
680
+ vqa_dataset="vqa",
681
+ vis_embed_size=None,
682
+ rank=0,
683
+ world_size=1,
684
+ id=0,
685
+ ):
686
+ """
687
+ Evaluate a model on VQA datasets. Currently supports VQAv2, OK-VQA, and GQA.
688
+
689
+ Args:
690
+ model (nn.Module): model to evaluate
691
+ tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
692
+ image_processor : image processor for the model
693
+ batch_size (int): batch size
694
+ image_dir_path (str): path to image directory
695
+ questions_json_path (str): path to questions json file
696
+ annotations_json_path (str): path to annotations json file
+ vis_embed_size (int): number of visual embedding tokens inserted per image
+ rank (int): rank of this process when running distributed evaluation
+ world_size (int): total number of evaluation processes
+ id: identifier used to name the per-rank temporary result files
+ vqa_dataset (string): type of vqa dataset; currently supports vqa, ok_vqa, gqa. Defaults to vqa.
707
+ Returns:
708
+ float: accuracy score
709
+ """
710
+ if world_size > 1:
711
+ torch.distributed.barrier()
712
+ if vqa_dataset == "gqa":
713
+ eval_dataset = GQADataset()
714
+ else:
715
+ eval_dataset = VQADataset(
716
+ image_dir_path=image_dir_path,
717
+ question_path=questions_json_path,
718
+ annotations_path=annotations_json_path,
719
+ vqa_dataset=vqa_dataset,
720
+ )
721
+ postprocessor = OKVQAPostProcess()
722
+ try:
723
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
724
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
725
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
726
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
727
+ except:
728
+ pass
729
+ def get_prompt(sample):
730
+ return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
731
+ # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
732
+
733
+ model.eval().cuda()
734
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
735
+ if "peft" in lang_encoder_name:
736
+ lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
737
+ predictions = []
738
+ tokenizer.padding_side = "left"
739
+ if world_size > 1:
740
+ torch.distributed.barrier()
741
+ this_tot = 0
742
+ for ii, batch in enumerate(more_itertools.chunked(
743
+ tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
744
+ )):
745
+ if ii % world_size != rank:
746
+ continue
747
+ batch_images = prepare_batch_images(
748
+ batch=batch,
749
+ image_processor=image_processor,
750
+ ).cuda()
751
+ batch_text = [get_prompt(s) for s in batch]
752
+ encodings = tokenizer(
753
+ batch_text,
754
+ return_tensors="pt",
755
+ padding="longest",
756
+ truncation=True,
757
+ max_length=2000,
758
+ )
759
+ input_ids = encodings["input_ids"].cuda()
760
+ attention_mask = encodings["attention_mask"].cuda()
761
+ skip_special_tokens = True
762
+ if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
763
+ if rank == 0:
764
+ tqdm.write("use legacy model")
765
+ for i in range(len(input_ids)):
766
+ media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
767
+ endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
768
+ input_ids[i, media_token_index - 1] = media_token_id
769
+ input_ids[i, media_token_index] = pad_token_id
770
+ input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
771
+ input_ids[i, endofmedia_token_index] = bos_token_id
772
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
773
+ image_start_index_list = [[x] for x in image_start_index_list]
774
+ image_nums = [1] * len(input_ids)
775
+ if "llama" in lang_encoder_name:
776
+ attention_mask[input_ids == 0] = 0
777
+ outputs = get_outputs(
778
+ model=model,
779
+ batch_images=batch_images,
780
+ attention_mask=attention_mask,
781
+ max_generation_length=10,
782
+ min_generation_length=1,
783
+ num_beams=5,
784
+ length_penalty=0,
785
+ input_ids=input_ids,
786
+ image_start_index_list=image_start_index_list,
787
+ image_nums=image_nums,
788
+ )
789
+ # postprocess begin
790
+ new_predictions = [
791
+ out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
792
+ ]
793
+ if vqa_dataset == "ok_vqa":
794
+ new_predictions = postprocessor._lemmatize(new_predictions)
795
+ if model.special:
796
+ for i in range(len(new_predictions)):
797
+ for answer, _ in Counter(batch[i]['answers']).most_common():
798
+ if answer in new_predictions[i]:
799
+ new_predictions[i] = answer
800
+ break
801
+ if "cant" in new_predictions[i] and "no" == answer:
802
+ new_predictions[i] = answer
803
+ break
804
+ if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
805
+ new_predictions[i] = answer
806
+ break
807
+
808
+ this_tot += 1
809
+ if rank == 0 and this_tot % 20 == 0:
810
+ for i in range(1):
811
+ tqdm.write("model output: " + new_predictions[i])
812
+
813
+ predictions.extend(
814
+ [
815
+ {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
816
+ for p, sample in zip(new_predictions, batch)
817
+ ]
818
+ )
819
+ with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
820
+ f.write(json.dumps(predictions))
821
+ print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
822
+
823
+ time.sleep(10)
824
+ if world_size > 1:
825
+ torch.distributed.barrier()
826
+ if rank == 0:
827
+ print(f"evaluate on rank {rank}. world size is {world_size}")
828
+ predictions = []
829
+ for rank_i in range(world_size):
830
+ print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
831
+ predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
832
+ os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
833
+ print("num:", len(predictions))
834
+ # save the predictions to a temporary file
835
+ random_uuid = str(uuid.uuid4())
836
+ with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
837
+ f.write(json.dumps(predictions, indent=4))
838
+
839
+ if vqa_dataset == "gqa":
840
+ acc = compute_gqa_accuracy(predictions)
841
+ else:
842
+ acc = compute_vqa_accuracy(
843
+ f"{vqa_dataset}results_{random_uuid}.json",
844
+ questions_json_path,
845
+ annotations_json_path,
846
+ vqa_dataset=vqa_dataset,
847
+ )
848
+ print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
849
+ os.makedirs("eval_results", exist_ok=True)
850
+ with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
851
+ f.write(json.dumps(predictions, indent=2))
852
+
853
+ # delete the temporary file
854
+ os.remove(f"{vqa_dataset}results_{random_uuid}.json")
855
+ else:
856
+ time.sleep(5)
857
+ acc = 0.0
858
+ if world_size > 1:
859
+ torch.distributed.barrier()
860
+ return acc
861
+
862
+
863
+ def evaluate_refcoco(
864
+ model,
865
+ tokenizer,
866
+ image_processor,
867
+ batch_size,
868
+ tsvfile,
869
+ max_generation_length=20,
870
+ num_beams=3,
871
+ length_penalty=-2.0,
872
+ device=-1,
873
+ vis_embed_size=None,
874
+ rank=0,
875
+ world_size=1,
876
+ id=0,
877
+ ):
878
+ model.eval().cuda()
879
+ loc_token_ids = []
880
+ for i in range(1000):
881
+ loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
882
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
883
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
884
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
885
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
886
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
887
+ object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
888
+ # all_ids = set(range(model.lang_encoder.lm_head.out_features))
889
+ # bad_words_ids = list(all_ids - set(loc_token_ids))
890
+ # bad_words_ids = [[b] for b in bad_words_ids]
891
+ # min_loc_token_id = min(loc_token_ids)
892
+ # max_loc_token_id = max(loc_token_ids)
893
+ total = 0
894
+ correct = 0
895
+ ious = []
896
+ if "refcocog" in tsvfile:
897
+ dataset_name = "refcocog"
898
+ elif "refcocoplus" in tsvfile:
899
+ dataset_name = "refcocoplus"
900
+ else:
901
+ dataset_name = "refcoco"
902
+ with open(tsvfile, "r") as f:
903
+ lines = f.readlines()
904
+ pbar = tqdm(lines, disable=(rank != 0))
905
+ for ii, line in enumerate(pbar):
906
+ if ii % world_size != rank:
907
+ continue
908
+ total += 1
909
+ line = line.rstrip()
910
+ uniq_id, image_id, text, region_coord, image = line.split("\t")
911
+
912
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
913
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
914
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
915
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
916
+
917
+ gt_box = np.array(list(map(float, region_coord.split(","))))
918
+ width = image.width
919
+ height = image.height
920
+ image = image.resize((224, 224))
921
+ gt_box = gt_box / np.array([width, height, width, height]) * 224
922
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
923
+ text = text.rstrip('.').strip().replace('"', '').capitalize()
924
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text}<|#endofobject#|><|#visual#|>"]
925
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
926
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
927
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
928
+
929
+ encodings = tokenizer(
930
+ prompt,
931
+ padding="longest",
932
+ truncation=True,
933
+ return_tensors="pt",
934
+ max_length=2000,
935
+ )
936
+ input_ids = encodings["input_ids"]
937
+ attention_mask = encodings["attention_mask"]
938
+ # attention_mask[input_ids == prebox_token_id] = 0
939
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
940
+ image_start_index_list = [[x] for x in image_start_index_list]
941
+ image_nums = [1] * len(input_ids)
942
+ vision_x = batch_images.cuda()
943
+ lang_x = input_ids.cuda()
944
+ attention_mask = attention_mask.cuda()
945
+
946
+ model.debug_id = 0
947
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):  # enter both contexts
948
+ outputs = model(
949
+ vision_x=vision_x,
950
+ lang_x=lang_x,
951
+ attention_mask=attention_mask,
952
+ labels=None,
953
+ image_nums=image_nums,
954
+ image_start_index_list=image_start_index_list,
955
+ added_bbox_list=None,
956
+ add_box=False,
957
+ )
958
+ boxes = outputs["boxes"]
959
+ scores = outputs["scores"]
960
+ boxes = boxes[scores >= scores[0]*0.5]
961
+ scores = scores[scores >= scores[0]*0.5]
962
+
963
+ text = text.lower().strip()
964
+ if text.split(" ")[0] not in ["a", "an", "the", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", "several", "some"]:
965
+ text = "a " + text
966
+ losses = []
967
+ for box, score in zip(boxes, scores):
968
+ this_prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>There is<|#object#|><|#previsual#|><|#prebox#|><|#object#|> {text}"]
969
+ encodings = tokenizer(
970
+ this_prompt,
971
+ padding="longest",
972
+ truncation=True,
973
+ return_tensors="pt",
974
+ max_length=2000,
975
+ )
976
+ input_ids = encodings["input_ids"]
977
+ attention_mask = encodings["attention_mask"]
978
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
979
+ image_start_index_list = [[x] for x in image_start_index_list]
980
+ image_nums = [1] * len(input_ids)
981
+ vision_x = batch_images.cuda()
982
+ lang_x = input_ids.cuda()
983
+ attention_mask = attention_mask.cuda()
984
+ added_bbox_list = [torch.tensor(box / 224).cuda().unsqueeze(0).clamp(0, 0.99)]
985
+ labels = lang_x.clone()
986
+ start_idx = (lang_x == object_token_id).nonzero()[-1, -1]
987
+ labels[0, :start_idx+1] = -100
988
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):  # enter both contexts
989
+ outputs = model(
990
+ vision_x=vision_x,
991
+ lang_x=lang_x,
992
+ attention_mask=attention_mask,
993
+ labels=labels,
994
+ image_nums=image_nums,
995
+ image_start_index_list=image_start_index_list,
996
+ added_bbox_list=added_bbox_list,
997
+ add_box=True,
998
+ )
999
+ # print(tokenizer.decode(outputs.logits[0, start_idx].sort(descending=True).indices[:10]))
1000
+ loss = outputs.loss.detach().cpu()
1001
+ losses.append((loss.sum() / (loss != 0).sum()).item())
1002
+ chosen_idx = np.array(losses).argmin()
1003
+ pred_box = boxes[chosen_idx]
1004
+ if chosen_idx != 0:
1005
+ tqdm.write(f"{text}|{chosen_idx}|{scores[chosen_idx]}")
1006
+ iou = get_iou(pred_box, gt_box)
1007
+ if iou >= 0.5:
1008
+ correct += 1
1009
+ # else:
1010
+ # if rank == 0:
1011
+ # tqdm.write(text.rstrip('.').strip().lower())
1012
+ # open_cv_image = np.array(image)
1013
+ # # Convert RGB to BGR
1014
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
1015
+ # open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
1016
+ # open_cv_image = cv2.rectangle(open_cv_image, gt_box[:2].astype(int), gt_box[2:].astype(int), (0, 255, 0), 2)
1017
+ # cv2.imwrite(f"refcocog_result/{ii}_{iou}_{text}.jpg", open_cv_image)
1018
+ pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
1019
+ # open_cv_image = np.array(image)
1020
+ # # Convert RGB to BGR
1021
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
1022
+ # for box, score in zip(boxes, scores):
1023
+ # open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
1024
+ # cv2.imwrite("output.jpg", open_cv_image)
1025
+ # print(boxes)
1026
+ # print(scores)
1027
+ # exit()
1028
+
1029
+
1030
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1031
+ f.write(json.dumps([total, correct]))
1032
+ if world_size > 1:
1033
+ torch.distributed.barrier()
1034
+ if rank == 0:
1035
+ total = 0
1036
+ correct = 0
1037
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1038
+ for rank_i in range(world_size):
1039
+ [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1040
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1041
+ total += total_part
1042
+ correct += correct_part
1043
+ score = correct / total
1044
+ print("score:", score)
1045
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
1046
+ pass
1047
+ else:
1048
+ score = 0.0
1049
+ if world_size > 1:
1050
+ torch.distributed.barrier()
1051
+ return score
1052
+
1053
+
1054
+
1055
+ # def preprocess_visual_info(Text):
1056
+ # text = Text.split(" ")
1057
+ # for is_idx, t in enumerate(text):
1058
+ # if t == "is":
1059
+ # break
1060
+ # the_idx = is_idx
1061
+ # while text[the_idx] != "the":
1062
+ # the_idx -= 1
1063
+ # obj_A = " ".join(text[the_idx+1:is_idx])
1064
+ # second_the_idx = len(text) - 1
1065
+ # while text[second_the_idx] != "the":
1066
+ # second_the_idx -= 1
1067
+ # obj_B = " ".join(text[second_the_idx+1:])
1068
+ # visual_obj_A = f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
1069
+ # visual_obj_B = f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
1070
+ # Text = Text.replace(obj_A, f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
1071
+ # Text = Text.replace(obj_B, f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
1072
+ # return Text, obj_A, obj_B, visual_obj_A, visual_obj_B
1073
+
1074
+
1075
+ def preprocess_visual_info(Text):
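+ # Heuristic parser for captions of the form "the <A> is <relation> the <B>": obj_A is the
+ # span between the "the" preceding the first "is" and that "is", obj_B follows the last
+ # "the", and the words in between form the relation; A is then wrapped in
+ # <|#visual#|><|#box#|> grounding tokens and B in <|#previsual#|><|#prebox#|> tokens.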
1076
+ text = Text.split(" ")
1077
+ for is_idx, t in enumerate(text):
1078
+ if t == "is":
1079
+ break
1080
+ the_idx = is_idx
1081
+ while text[the_idx] != "the":
1082
+ the_idx -= 1
1083
+ obj_A = " ".join(text[the_idx+1:is_idx])
1084
+ second_the_idx = len(text) - 1
1085
+ while text[second_the_idx] != "the":
1086
+ second_the_idx -= 1
1087
+ obj_B = " ".join(text[second_the_idx+1:])
1088
+ relation = " ".join(text[is_idx+1:second_the_idx])
1089
+ visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
1090
+ visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>"
1091
+ Text = f"{visual_obj_A} is {relation} {visual_obj_B}"
1092
+ return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation
1093
+
1094
+
1095
+
1096
+
1097
+ def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
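+ # Single forward pass in box-prediction mode: the prompt typically ends with a
+ # <|#visual#|> or <|#previsual#|> query, and the model returns candidate boxes with
+ # scores. With return_all=True all candidates are returned; otherwise only the
+ # highest-scoring box and its score.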
1098
+ assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
1099
+ encodings = tokenizer(
1100
+ prompt,
1101
+ padding="longest",
1102
+ truncation=True,
1103
+ return_tensors="pt",
1104
+ max_length=2000,
1105
+ )
1106
+ input_ids = encodings["input_ids"]
1107
+ attention_mask = encodings["attention_mask"]
1108
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1109
+ image_start_index_list = [[x] for x in image_start_index_list]
1110
+ image_nums = [1] * len(input_ids)
1111
+ vision_x = batch_images.cuda()
1112
+ lang_x = input_ids.cuda()
1113
+ attention_mask = attention_mask.cuda()
1114
+
1115
+ model.debug_id = 0
1116
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
1117
+ outputs = model(
1118
+ vision_x=vision_x,
1119
+ lang_x=lang_x,
1120
+ attention_mask=attention_mask,
1121
+ labels=None,
1122
+ image_nums=image_nums,
1123
+ image_start_index_list=image_start_index_list,
1124
+ added_bbox_list=visual_box_list,
1125
+ add_box=visual_box_list is not None,
1126
+ relations=None,
1127
+ debug_mode=False,
1128
+ )
1129
+ boxes = outputs["boxes"]
1130
+ scores = outputs["scores"]
1131
+ if debug:
1132
+ import pdb; pdb.set_trace()
1133
+ if return_all:
1134
+ return boxes, scores
1135
+ if len(scores) == 0:
1136
+ return None, None
1137
+ else:
1138
+ return boxes[scores.argmax()], scores.max()
1139
+
1140
+
1141
+ def evaluate_aro(
1142
+ model,
1143
+ tokenizer,
1144
+ image_processor,
1145
+ vis_embed_size=None,
1146
+ rank=0,
1147
+ world_size=1,
1148
+ id=0,
1149
+ add_visual=True,
1150
+ subset=False,
1151
+ choose_left_right=False,
1152
+ ):
1153
+ # os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
1154
+ dataset_name = "aro"
1155
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
1156
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
1157
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
1158
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
1159
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
1160
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
1161
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
1162
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
1163
+ model.eval().cuda()
1164
+ total = 0
1165
+ n_top1 = 0
1166
+ n_top5 = 0
1167
+ from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
1168
+ vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
1169
+ if subset:
1170
+ subset_idx = json.load(open("aro_subset.json"))
1171
+ pbar = tqdm(subset_idx, disable=(rank != 0))
1172
+ else:
1173
+ pbar = tqdm(vgr_dataset, disable=(rank != 0))
1174
+ for ii, sample in enumerate(pbar):
1175
+ if subset:
1176
+ ORI_IDX = int(sample)
1177
+ sample = vgr_dataset[sample]
1178
+ if ii % world_size != rank:
1179
+ continue
1180
+ image = sample["image_options"][0]
1181
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
1182
+ image = image.resize((224, 224))
1183
+
1184
+ text = sample["caption_options"][1] # 1 is true caption
1185
+ # text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog"
1186
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1187
+ text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
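+ # Two-stage grounding: first locate obj_A with a <|#visual#|> query, then condition on
+ # that box to propose candidate boxes for obj_B via a <|#previsual#|> query.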
1188
+
1189
+
1190
+ first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
1191
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
1192
+ first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
1193
+
1194
+ if first_box is None:
1195
+ text_A = "the " + obj_A
1196
+ added_bbox_list = None
1197
+ else:
1198
+ text_A = visual_obj_A
1199
+ added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
1200
+
1201
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
1202
+ pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
1203
+ prebox_token_id, return_all=True)
1204
+
1205
+ if pre_boxes is None:
1206
+ pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
1207
+ pre_scores = [1.0]
1208
+
1209
+ logits_list = []
1210
+ # pre_boxes = [pre_boxes[0]]
1211
+ # pre_scores = [pre_scores[0]]
1212
+ for pre_box, pre_score in zip(pre_boxes, pre_scores):
1213
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
1214
+
1215
+ encodings = tokenizer(
1216
+ prompt,
1217
+ padding="longest",
1218
+ truncation=True,
1219
+ return_tensors="pt",
1220
+ max_length=512,
1221
+ )
1222
+ input_ids = encodings["input_ids"]
1223
+ attention_mask = encodings["attention_mask"]
1224
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1225
+ image_start_index_list = [[x] for x in image_start_index_list]
1226
+ image_nums = [1] * len(input_ids)
1227
+ vision_x = batch_images.cuda()
1228
+ lang_x = input_ids.cuda()
1229
+ attention_mask = attention_mask.cuda()
1230
+ labels = lang_x.clone()
1231
+ added_bbox_list = None
1232
+ if add_visual:
1233
+ added_bbox_list = []
1234
+ if first_box is not None:
1235
+ added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
1236
+ if pre_box is not None:
1237
+ added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
1238
+ if added_bbox_list is not None and len(added_bbox_list) == 0:
1239
+ added_bbox_list = None
1240
+
1241
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
1242
+ outputs = model(
1243
+ vision_x=vision_x,
1244
+ lang_x=lang_x,
1245
+ attention_mask=attention_mask,
1246
+ labels=labels,
1247
+ image_nums=image_nums,
1248
+ image_start_index_list=image_start_index_list,
1249
+ added_bbox_list=added_bbox_list,
1250
+ add_box=added_bbox_list is not None,
1251
+ relations=None,
1252
+ )
1253
+ logits_list.append([pre_score, outputs.logits])
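+ # Marginalize over the candidate pre-boxes for obj_B: weight each candidate's next-token
+ # distribution by its box score, then check whether obj_B ranks in the top-1 / top-5.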
1254
+ pre_scores = np.array([x[0] for x in logits_list])
1255
+ final_probs = 0.0
1256
+ for score, (_, logits) in zip(pre_scores, logits_list):
1257
+ final_probs += score * logits.softmax(-1)
1258
+ assert input_ids.shape[:2] == final_probs.shape[:2]
1259
+ _rank, is_top1, is_top5 = is_correct(input_ids, final_probs, tokenizer, obj_B, topk=5)
1260
+ if is_top1:
1261
+ n_top1 += 1
1262
+ if is_top5:
1263
+ n_top5 += 1
1264
+ total += 1
1265
+ pbar.set_description(f"acc@top1: {n_top1 / total:.4f} | acc@top5: {n_top5 / total:.4f} | {_rank}")
1266
+
1267
+
1268
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1269
+ f.write(json.dumps([total, n_top1, n_top5]))
1270
+ if world_size > 1:
1271
+ torch.distributed.barrier()
1272
+ if rank == 0:
1273
+ total = 0
1274
+ n_top1 = 0
1275
+ n_top5 = 0
1276
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1277
+ for rank_i in range(world_size):
1278
+ [total_part, n_top1_part, n_top5_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1279
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1280
+ total += total_part
1281
+ n_top1 += n_top1_part
1282
+ n_top5 += n_top5_part
1283
+ acc_top1 = n_top1 / total
1284
+ acc_top5 = n_top5 / total
1285
+ print("acc_top1:", acc_top1, "acc_top5:", acc_top5, "total:", total)
1286
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc_top1}_{acc_top5}_{total}_{subset}"), "w") as f:
1287
+ pass
1288
+ else:
1289
+ score = 0.0
1290
+ if world_size > 1:
1291
+ torch.distributed.barrier()
1292
+ return score
1293
+
1294
+
1295
+ def evaluate_pisc(
1296
+ model,
1297
+ tokenizer,
1298
+ image_processor,
1299
+ batch_size,
1300
+ tsvfile,
1301
+ max_generation_length=20,
1302
+ num_beams=3,
1303
+ length_penalty=-2.0,
1304
+ device=-1,
1305
+ vis_embed_size=None,
1306
+ rank=0,
1307
+ world_size=1,
1308
+ id=0,
1309
+ add_visual=True,
1310
+ ):
1311
+ from open_flamingo.train.instruction_template import PISC_TEMPLATES
1312
+ dataset_name = "pisc"
1313
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
1314
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
1315
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
1316
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
1317
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
1318
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
1319
+ model.eval().cuda()
1320
+
1321
+ dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd")
1322
+ pbar = tqdm(dataset, disable=(rank != 0))
1323
+
1324
+ rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
1325
+ rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)}
1326
+ gt = []
1327
+ pred_scores = []
1328
+ for III, sample in enumerate(pbar):
1329
+ if III % world_size != rank:
1330
+ continue
1331
+ image_path, dataset, data = sample
1332
+ image = Image.open(image_path)
1333
+ size = image_processor.transforms[0].size
1334
+ image = image.resize((size, size))
1335
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1336
+ boxA = data[0]
1337
+ boxB = data[1]
1338
+ gt_relation = data[2]
1339
+ losses = []
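+ # Score every candidate relation by the language-model loss of its filled-in template;
+ # the relation with the lowest loss (highest likelihood) becomes the prediction.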
1340
+ for i_rel, option_rel in enumerate(rel_id_to_type):
1341
+ text = PISC_TEMPLATES[0].format(relation=option_rel)
1342
+ added_bbox = [
1343
+ torch.tensor([boxA]).cuda(),
1344
+ torch.tensor([boxB]).cuda(),
1345
+ ]
1346
+ caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}"
1347
+ encodings = tokenizer(
1348
+ caption,
1349
+ padding="longest",
1350
+ truncation=True,
1351
+ return_tensors="pt",
1352
+ max_length=2000,
1353
+ )
1354
+ input_ids = encodings["input_ids"]
1355
+ attention_mask = encodings["attention_mask"]
1356
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1357
+ image_start_index_list = [[x] for x in image_start_index_list]
1358
+ image_nums = [1] * len(input_ids)
1359
+ vision_x = batch_images.cuda()
1360
+ lang_x = input_ids.cuda()
1361
+ attention_mask = attention_mask.cuda()
1362
+
1363
+ labels = lang_x.clone()
1364
+ labels[labels == tokenizer.pad_token_id] = -100
1365
+ if add_visual:
1366
+ # endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True))
1367
+ # endofattr_next_token_index[1] += 1
1368
+ # endofattr_next_token_id = labels[endofattr_next_token_index]
1369
+ # </obj><visual><box></attr>NEXT_WORD
1370
+ # </obj> predict NEXT_WORD
1371
+ # <visual><box></attr> predict nothing
1372
+ labels[labels == visual_token_id] = -100
1373
+ labels[labels == box_token_id] = -100
1374
+ labels[labels == endofattr_token_id] = -100
1375
+ # labels[endofattr_next_token_index] = -100
1376
+ labels[:, 0] = -100
1377
+ answer_token_id = tokenizer(" Answer").input_ids[0]
1378
+ answer_token_loc = (input_ids == answer_token_id).nonzero()
1379
+ for batch_idx, idx in answer_token_loc:
1380
+ labels[batch_idx][:idx+2] = -100
1381
+
1382
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
1383
+ outputs = model(
1384
+ vision_x=vision_x,
1385
+ lang_x=lang_x,
1386
+ attention_mask=attention_mask,
1387
+ labels=labels,
1388
+ image_nums=image_nums,
1389
+ image_start_index_list=image_start_index_list,
1390
+ added_bbox_list=added_bbox,
1391
+ add_box=added_bbox is not None,
1392
+ )
1393
+ loss_total = outputs.loss.reshape(labels.shape[0], -1)
1394
+ loss = loss_total.sum() / (loss_total != 0).sum()
1395
+ losses.append(loss.item())
1396
+ pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum())
1397
+ gt.append(rel_type_to_id[gt_relation])
1398
+ gt = np.array(gt)
1399
+ pred_scores = np.array(pred_scores)
1400
+ pred = pred_scores.argmax(1)
1401
+
1402
+
1403
+ print("total num:", len(gt))
1404
+ recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
1405
+ print("recalls:", recalls)
1406
+
1407
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1408
+ f.write(json.dumps([gt.tolist(), pred.tolist()]))
1409
+ if world_size > 1:
1410
+ torch.distributed.barrier()
1411
+ if rank == 0:
1412
+ gt = []
1413
+ pred = []
1414
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1415
+ for rank_i in range(world_size):
1416
+ [gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1417
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1418
+ gt.extend(gt_part)
1419
+ pred.extend(pred_part)
1420
+ print("total num:", len(gt))
1421
+ recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
1422
+ print("recalls:", recalls)
1423
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f:
1424
+ f.write(f"{gt}\n")
1425
+ f.write(f"{pred}\n")
1426
+ f.write(f"{recalls}\n")
1427
+ score = 0.0
1428
+ if world_size > 1:
1429
+ torch.distributed.barrier()
1430
+ return score
1431
+
1432
+
1433
+
1434
+ if __name__ == "__main__":
1435
+ main()
multimodal/build/lib/open_flamingo/eval/evaluate_debug.py ADDED
@@ -0,0 +1,1159 @@
1
+ import argparse
2
+ import json
3
+ from math import ceil
4
+ import os
5
+ import random
6
+ import uuid
7
+ from collections import defaultdict
8
+ from typing import Callable
9
+ import time
10
+ import cv2
+ import logging
11
+
12
+ import more_itertools
13
+ import numpy as np
14
+ import torch
15
+ from coco_metric import compute_cider, postprocess_captioning_generation
16
+ from eval_datasets import VQADataset, GQADataset
17
+ from tqdm import tqdm
18
+ from collections import Counter
19
+
20
+ from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
21
+ from open_flamingo.eval.classification import (
22
+ compute_per_sample_probs,
23
+ compute_per_sample_loss,
24
+ )
25
+ from open_flamingo.eval.imagenet_utils import (
26
+ openai_imagenet_classnames,
27
+ IMAGENET_1K_CLASS_ID_TO_LABEL,
28
+ )
29
+
30
+ from open_flamingo.src.factory import create_model_and_transforms
31
+ from PIL import Image
32
+ from io import BytesIO
33
+ import base64
34
+ from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
35
+ import string
36
+ from lavis.datasets.builders import load_dataset
37
+
38
+
39
+ def get_iou(box1, box2):
40
+ # box1 and box2 should be in the format [x1, y1, x2, y2]
41
+ intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
42
+ max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
43
+ area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
44
+ area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
45
+ union = area_box1 + area_box2 - intersection
46
+ iou = intersection / union if union > 0 else 0
47
+ return iou
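+ # e.g. get_iou([0, 0, 10, 10], [5, 5, 15, 15]) = 25 / 175 ≈ 0.143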
48
+
49
+ def expand2square(pil_img, background_color):
50
+ width, height = pil_img.size
51
+ if width == height:
52
+ return pil_img
53
+ elif width > height:
54
+ result = Image.new(pil_img.mode, (width, width), background_color)
55
+ result.paste(pil_img, (0, (width - height) // 2))
56
+ return result
57
+ else:
58
+ result = Image.new(pil_img.mode, (height, height), background_color)
59
+ result.paste(pil_img, ((height - width) // 2, 0))
60
+ return result
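+ # Pads the shorter side with background_color so the image becomes square and keeps its
+ # aspect ratio when later resized to a fixed resolution.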
61
+
62
+ parser = argparse.ArgumentParser()
63
+ parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
64
+ parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
65
+ parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
66
+ parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
67
+ parser.add_argument("--checkpoint_path", type=str, required=True)
68
+ parser.add_argument(
69
+ "--results_file", type=str, default=None, help="JSON file to save results"
70
+ )
71
+
72
+ # Trial arguments
73
+ parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
74
+ parser.add_argument(
75
+ "--num_trials",
76
+ type=int,
77
+ default=1,
78
+ help="Number of trials to run for each shot using different demonstrations",
79
+ )
80
+ parser.add_argument(
81
+ "--trial_seeds",
82
+ nargs="+",
83
+ default=[0],
84
+ help="Seeds to use for each trial for picking demonstrations and eval sets",
85
+ )
86
+ parser.add_argument(
87
+ "--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
88
+ )
89
+
90
+ parser.add_argument("--batch_size", type=int, default=8)
91
+
92
+ # Per-dataset evaluation flags
93
+ parser.add_argument(
94
+ "--eval_coco",
95
+ action="store_true",
96
+ default=False,
97
+ help="Whether to evaluate on COCO.",
98
+ )
99
+ parser.add_argument(
100
+ "--eval_vqav2",
101
+ action="store_true",
102
+ default=False,
103
+ help="Whether to evaluate on VQAV2.",
104
+ )
105
+ parser.add_argument(
106
+ "--eval_ok_vqa",
107
+ action="store_true",
108
+ default=False,
109
+ help="Whether to evaluate on OK-VQA.",
110
+ )
111
+ parser.add_argument(
112
+ "--eval_imagenet",
113
+ action="store_true",
114
+ default=False,
115
+ help="Whether to evaluate on ImageNet.",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--eval_flickr30",
120
+ action="store_true",
121
+ default=False,
122
+ help="Whether to evaluate on Flickr30.",
123
+ )
124
+
125
+ parser.add_argument(
126
+ "--eval_refcoco",
127
+ action="store_true",
128
+ default=False,
129
+ help="Whether to evaluate on RefCOCO.",
130
+ )
131
+
132
+ # Dataset arguments
133
+
134
+ ## Flickr30 Dataset
135
+ parser.add_argument(
136
+ "--flickr_image_dir_path",
137
+ type=str,
138
+ help="Path to the flickr30/flickr30k_images directory.",
139
+ default=None,
140
+ )
141
+ parser.add_argument(
142
+ "--flickr_annotations_json_path",
143
+ type=str,
144
+ help="Path to the dataset_flickr30k_coco_style.json file.",
145
+ default=None,
146
+ )
147
+
148
+ ## COCO Dataset
149
+ parser.add_argument(
150
+ "--coco_image_dir_path",
151
+ type=str,
152
+ help="Path to the flickr30/flickr30k_images directory.",
153
+ default=None,
154
+ )
155
+ parser.add_argument(
156
+ "--coco_annotations_json_path",
157
+ type=str,
158
+ default=None,
159
+ )
160
+
161
+ ## VQAV2 Dataset
162
+ parser.add_argument(
163
+ "--vqav2_image_dir_path",
164
+ type=str,
165
+ default=None,
166
+ )
167
+ parser.add_argument(
168
+ "--vqav2_questions_json_path",
169
+ type=str,
170
+ default=None,
171
+ )
172
+ parser.add_argument(
173
+ "--vqav2_annotations_json_path",
174
+ type=str,
175
+ default=None,
176
+ )
177
+
178
+ ## OK-VQA Dataset
179
+ parser.add_argument(
180
+ "--ok_vqa_image_dir_path",
181
+ type=str,
182
+ help="Path to the vqav2/train2014 directory.",
183
+ default=None,
184
+ )
185
+ parser.add_argument(
186
+ "--ok_vqa_questions_json_path",
187
+ type=str,
188
+ help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
189
+ default=None,
190
+ )
191
+ parser.add_argument(
192
+ "--ok_vqa_annotations_json_path",
193
+ type=str,
194
+ help="Path to the v2_mscoco_train2014_annotations.json file.",
195
+ default=None,
196
+ )
197
+
198
+ ## Imagenet dataset
199
+ parser.add_argument("--imagenet_root", type=str, default="/tmp")
200
+
201
+ ## RefCOCO dataset
202
+ parser.add_argument("--refcoco_tsvfile", type=str, default=None)
203
+
204
+ parser.add_argument(
205
+ "--location_token_num",
206
+ default=1000,
207
+ type=int,
208
+ )
209
+ # distributed training
210
+ parser.add_argument(
211
+ "--dist-url",
212
+ default="env://",
213
+ type=str,
214
+ help="url used to set up distributed training",
215
+ )
216
+ parser.add_argument(
217
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
218
+ )
219
+ parser.add_argument(
220
+ "--horovod",
221
+ default=False,
222
+ action="store_true",
223
+ help="Use horovod for distributed training.",
224
+ )
225
+ parser.add_argument(
226
+ "--no-set-device-rank",
227
+ default=False,
228
+ action="store_true",
229
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
230
+ )
231
+ parser.add_argument(
232
+ "--dist",
233
+ default=False,
234
+ action="store_true",
235
+ )
236
+ parser.add_argument(
237
+ "--lora",
238
+ default=False,
239
+ action="store_true",
240
+ )
241
+ parser.add_argument(
242
+ "--lora_r",
243
+ default=16,
244
+ type=int,
245
+ required=False,
246
+ )
247
+ parser.add_argument(
248
+ "--legacy",
249
+ default=False,
250
+ action="store_true",
251
+ )
252
+ parser.add_argument(
253
+ "--special",
254
+ default=False,
255
+ action="store_true",
256
+ )
257
+ parser.add_argument(
258
+ "--id",
259
+ default=0,
260
+ type=int,
261
+ required=False,
262
+ )
263
+
264
+ parser.add_argument(
265
+ "--eval_gqa",
266
+ default=False,
267
+ action="store_true",
268
+ )
269
+ parser.add_argument(
270
+ "--use_sam",
271
+ default=None,
272
+ type=str,
273
+ required=False,
274
+ )
275
+ parser.add_argument(
276
+ "--add_visual_token",
277
+ default=False,
278
+ action="store_true",
279
+ )
280
+ parser.add_argument(
281
+ "--use_format_v2",
282
+ default=False,
283
+ action="store_true",
284
+ )
285
+
286
+
287
+ class OKVQAPostProcess():
288
+ def __init__(self):
289
+ self._lemmatizer = None
290
+
291
+ def _lemmatize(self, answers):
292
+ def apply(answer):
293
+ doc = self.lemmatizer(answer)
294
+
295
+ words = []
296
+ for token in doc:
297
+ if token.pos_ in ["NOUN", "VERB"]:
298
+ words.append(token.lemma_)
299
+ else:
300
+ words.append(token.text)
301
+ answer = " ".join(words)
302
+
303
+ return answer
304
+
305
+ return [apply(answer) for answer in answers]
306
+
307
+ @property
308
+ def lemmatizer(self):
309
+ if self._lemmatizer is None:
310
+ try:
311
+ import spacy
312
+
313
+ self._lemmatizer = spacy.load("en_core_web_sm")
314
+ except ImportError:
315
+ logging.error(
316
+ """
317
+ Please install spacy and en_core_web_sm model to apply lemmatization.
318
+ python -m spacy download en_core_web_sm
319
+ OR
320
+ import spacy.cli
321
+ spacy.cli.download("en_core_web_sm")
322
+ """
323
+ )
324
+ exit(1)
325
+
326
+ return self._lemmatizer
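+ # OK-VQA predictions are lemmatized (nouns and verbs reduced to their base form) before
+ # accuracy is computed, so e.g. "running" and "run" count as the same answer.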
327
+
328
+
329
+ def main():
330
+ args = parser.parse_args()
331
+ if args.dist:
332
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
333
+ print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
334
+ device_id = init_distributed_device(args)
335
+ else:
336
+ args.rank = 0
337
+ args.world_size = 1
338
+ print(f"rank: {args.rank} world_size: {args.world_size}")
339
+
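+ # Model variants are inferred from substrings of the checkpoint path (e.g. "sam", "lora",
+ # "box", "pe", "rel", "previsual") rather than from explicit flags.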
340
+ if "sam" in args.checkpoint_path:
341
+ args.use_sam = "vit_l"
342
+
343
+ args.add_visual_token = True
344
+ if "lora" in args.checkpoint_path:
345
+ args.lora = True
346
+
347
+
348
+ args.add_pe = False
349
+ args.add_box = False
350
+ args.relation = False
351
+ if "debug" in args.checkpoint_path:
352
+ # args.add_pe = True
353
+ args.add_box = True
354
+ if "box" in args.checkpoint_path:
355
+ args.add_box = True
356
+ if "pe" in args.checkpoint_path:
357
+ args.add_pe = True
358
+ if "rel" in args.checkpoint_path:
359
+ args.relation = True
360
+ args.add_pe = False
361
+ if "previsual" in args.checkpoint_path:
362
+ args.use_format_v2 = True
363
+ args.relation = False
364
+
365
+
366
+
367
+ # load model
368
+ flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
369
+ args.vision_encoder_path,
370
+ args.vision_encoder_pretrained,
371
+ args.lm_path,
372
+ args.lm_tokenizer_path,
373
+ location_token_num=args.location_token_num,
374
+ lora=args.lora,
375
+ lora_r=16,
376
+ use_sam=args.use_sam,
377
+ add_visual_token=args.add_visual_token,
378
+ use_format_v2=args.use_format_v2,
379
+ add_box=args.add_box,
380
+ add_pe=args.add_pe,
381
+ add_relation=args.relation,
382
+ )
383
+ flamingo.use_format_v2 = args.use_format_v2
384
+ if args.special:
385
+ flamingo.special = True
386
+ else:
387
+ flamingo.special = False
388
+ if args.legacy:
389
+ flamingo.legacy = True
390
+ print("use legacy evaluation")
391
+ flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
392
+ flamingo.expr_name = args.checkpoint_path.split("/")[-2]
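+ # The step number is parsed from the trailing "_<step>" of the checkpoint filename and
+ # the experiment name from its parent directory.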
393
+ if args.rank == 0:
394
+ print("legacy", True if hasattr(flamingo, "legacy") else False)
395
+ print("step:", flamingo.step_num)
396
+ print("expr:", flamingo.expr_name)
397
+ print("use format v2:", flamingo.use_format_v2)
398
+ print(args)
399
+ checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
400
+ model_state_dict = {}
401
+ for key in checkpoint["model_state_dict"].keys():
402
+ model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
403
+ if "vision_encoder.logit_scale"in model_state_dict:
404
+ # previous checkpoint has some unnecessary weights
405
+ del model_state_dict["vision_encoder.logit_scale"]
406
+ del model_state_dict["vision_encoder.visual.proj"]
407
+ del model_state_dict["vision_encoder.visual.ln_post.weight"]
408
+ del model_state_dict["vision_encoder.visual.ln_post.bias"]
409
+ flamingo.load_state_dict(model_state_dict, strict=True)
410
+ results = defaultdict(list)
411
+ if args.eval_coco:
412
+ print("Evaluating on COCO...")
413
+ for shot in args.shots:
414
+ scores = []
415
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
416
+ cider_score = evaluate_coco_flickr(
417
+ model=flamingo,
418
+ tokenizer=tokenizer,
419
+ image_processor=image_processor,
420
+ batch_size=args.batch_size,
421
+ image_dir_path=args.coco_image_dir_path,
422
+ annotations_json_path=args.coco_annotations_json_path,
423
+ device=args.device,
424
+ seed=seed,
425
+ vis_embed_size=vis_embed_size,
426
+ rank=args.rank,
427
+ world_size=args.world_size,
428
+ id=args.id,
429
+ )
430
+ print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
431
+ scores.append(cider_score)
432
+ print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
433
+ results["coco"].append(
434
+ {"shots": shot, "trials": scores, "mean": np.mean(scores)}
435
+ )
436
+
437
+ if args.eval_ok_vqa:
438
+ print("Evaluating on OK-VQA...")
439
+ for shot in args.shots:
440
+ scores = []
441
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
442
+ ok_vqa_score = evaluate_vqa(
443
+ model=flamingo,
444
+ tokenizer=tokenizer,
445
+ image_processor=image_processor,
446
+ batch_size=args.batch_size,
447
+ image_dir_path=args.ok_vqa_image_dir_path,
448
+ questions_json_path=args.ok_vqa_questions_json_path,
449
+ annotations_json_path=args.ok_vqa_annotations_json_path,
450
+ vqa_dataset="ok_vqa",
451
+ vis_embed_size=vis_embed_size,
452
+ rank=args.rank,
453
+ world_size=args.world_size,
454
+ id=args.id,
455
+ )
456
+ results["ok_vqa"].append(
457
+ {"shots": shot, "score": ok_vqa_score}
458
+ )
459
+
460
+ if args.eval_vqav2:
461
+ print("Evaluating on VQAv2...")
462
+ for shot in args.shots:
463
+ scores = []
464
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
465
+ vqa_score = evaluate_vqa(
466
+ model=flamingo,
467
+ tokenizer=tokenizer,
468
+ image_processor=image_processor,
469
+ batch_size=args.batch_size,
470
+ image_dir_path=args.vqav2_image_dir_path,
471
+ questions_json_path=args.vqav2_questions_json_path,
472
+ annotations_json_path=args.vqav2_annotations_json_path,
473
+ vqa_dataset="vqa",
474
+ vis_embed_size=vis_embed_size,
475
+ rank=args.rank,
476
+ world_size=args.world_size,
477
+ id=args.id,
478
+ )
479
+ results["vqav2"].append(
480
+ {"shots": shot, "score": vqa_score}
481
+ )
482
+
483
+ if args.eval_gqa:
484
+ print("Evaluating on GQA...")
485
+ for shot in args.shots:
486
+ scores = []
487
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
488
+ vqa_score = evaluate_vqa(
489
+ model=flamingo,
490
+ tokenizer=tokenizer,
491
+ image_processor=image_processor,
492
+ batch_size=args.batch_size,
493
+ vqa_dataset="gqa",
494
+ vis_embed_size=vis_embed_size,
495
+ rank=args.rank,
496
+ world_size=args.world_size,
497
+ id=args.id,
498
+ )
499
+ results["gqa"].append(
500
+ {"shots": shot, "score": vqa_score}
501
+ )
502
+
503
+ if args.eval_imagenet:
504
+ print("Evaluating on ImageNet...")
505
+ for shot in args.shots:
506
+ scores = []
507
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
508
+ imagenet_score = evaluate_imagenet(
509
+ model=flamingo,
510
+ tokenizer=tokenizer,
511
+ image_processor=image_processor,
512
+ batch_size=args.batch_size,
513
+ num_samples=args.num_samples,
514
+ num_shots=shot,
515
+ device=args.device,
516
+ seed=seed,
517
+ imagenet_root=args.imagenet_root,
518
+ )
519
+ print(
520
+ f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
521
+ )
522
+ scores.append(imagenet_score)
523
+ print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
524
+ results["imagenet"].append(
525
+ {"shots": shot, "trials": scores, "mean": np.mean(scores)}
526
+ )
527
+
528
+ if args.eval_refcoco:
529
+ print("Evaluating on RefCOCO...")
530
+ refcoco_score = evaluate_refcoco(
531
+ model=flamingo,
532
+ tokenizer=tokenizer,
533
+ image_processor=image_processor,
534
+ batch_size=args.batch_size,
535
+ device=args.device,
536
+ tsvfile=args.refcoco_tsvfile,
537
+ vis_embed_size=vis_embed_size,
538
+ rank=args.rank,
539
+ world_size=args.world_size,
540
+ id=args.id,
541
+ )
542
+ results["refcoco"].append(
543
+ {"score": refcoco_score}
544
+ )
545
+
546
+ def prepare_batch_images(batch, image_processor):
547
+ batch_images = None
548
+ for b in batch:
549
+ b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
550
+ if batch_images is None:
551
+ batch_images = b_image
552
+ else:
553
+ batch_images = torch.cat([batch_images, b_image], dim=0)
554
+ return batch_images
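+ # Each image becomes a [1, 1, 1, C, H, W] tensor; concatenating along dim 0 yields a
+ # [B, 1, 1, C, H, W] batch (presumably batch, images per sample, frames, channels, H, W).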
555
+
556
+ def get_outputs(
557
+ model,
558
+ batch_images,
559
+ attention_mask,
560
+ max_generation_length,
561
+ min_generation_length,
562
+ num_beams,
563
+ length_penalty,
564
+ input_ids,
565
+ image_start_index_list=None,
566
+ image_nums=None,
567
+ bad_words_ids=None,
568
+ ):
569
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
570
+ outputs = model.generate(
571
+ batch_images,
572
+ input_ids,
573
+ attention_mask=attention_mask,
574
+ max_new_tokens=max_generation_length,
575
+ min_length=min_generation_length,
576
+ num_beams=num_beams,
577
+ length_penalty=length_penalty,
578
+ image_start_index_list=image_start_index_list,
579
+ image_nums=image_nums,
580
+ bad_words_ids=bad_words_ids,
581
+ )
582
+
583
+ outputs = outputs[:, len(input_ids[0]) :]
584
+ return outputs
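+ # The prompt tokens are stripped off, so only newly generated tokens are returned.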
585
+
586
+
587
+ def evaluate_coco_flickr(
588
+ model,
589
+ tokenizer,
590
+ image_processor,
591
+ batch_size,
592
+ image_dir_path,
593
+ annotations_json_path,
594
+ seed=42,
595
+ max_generation_length=20,
596
+ num_beams=1,
597
+ length_penalty=-2.0,
598
+ device=-1,
599
+ is_flickr=False,
600
+ vis_embed_size=None,
601
+ rank=0,
602
+ world_size=1,
603
+ id=0,
604
+ ):
605
+ """Evaluate a model on COCO dataset.
606
+
607
+ Args:
608
+ model (nn.Module): model to evaluate
609
+ tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
610
+ image_processor : image processor for the model
611
+ batch_size (int): batch size
612
+ image_dir_path (str, optional): path to the directory containing the images.
613
+ annotations_json_path (str, optional): path to the json file containing the annotations.
614
+ seed (int, optional): seed for random number generator. Defaults to 42.
615
+ max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20.
616
+ num_beams (int, optional): number of beams to use for beam search. Defaults to 1.
617
+ length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
618
+ num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
619
+ query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
620
+ num_shots (int, optional): number of in-context samples to use. Defaults to 8.
621
+ device (int, optional): device to use. Defaults to -1.
622
+ num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
623
+ is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
624
+
625
+ Returns:
626
+ float: CIDEr score
627
+
628
+ """
629
+ # eval_dataset = COCOFlickrDataset(
630
+ # image_dir_path=image_dir_path,
631
+ # annotations_path=annotations_json_path,
632
+ # is_flickr=is_flickr,
633
+ # )
634
+ coco_dataset = load_dataset("coco_caption")
635
+ eval_dataset = coco_dataset["test"]
636
+
637
+
638
+ model.eval().cuda()
639
+ predictions = defaultdict()
640
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
641
+ # if "peft" in lang_encoder_name:
642
+ # lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
643
+ try:
644
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
645
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
646
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
647
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
648
+ except:
649
+ pass
650
+
651
+ def get_prompt(sample):
652
+ return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
653
+
654
+ tokenizer.padding_side = "left"
655
+ cnt = 0
656
+ if world_size > 1:
657
+ torch.distributed.barrier()
658
+ desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
659
+ for ii, batch in enumerate(more_itertools.chunked(
660
+ tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
661
+ )):
662
+ if ii % world_size != rank:
663
+ continue
664
+ cnt += len(batch)
665
+ batch_images = prepare_batch_images(
666
+ batch=batch,
667
+ image_processor=image_processor,
668
+ ).cuda()
669
+ batch_text = [get_prompt(s) for s in batch]
670
+ encodings = tokenizer(
671
+ batch_text,
672
+ padding="longest",
673
+ truncation=True,
674
+ return_tensors="pt",
675
+ max_length=2000,
676
+ )
677
+ input_ids = encodings["input_ids"].cuda()
678
+ attention_mask = encodings["attention_mask"].cuda()
679
+ skip_special_tokens = False
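+ # Legacy OPT checkpoints expect the image markers one position earlier: the loop below
+ # shifts <|#image#|> / <|#endofimage#|> left by one and back-fills the vacated positions
+ # with pad / bos tokens.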
680
+ if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
681
+ if rank == 0:
682
+ tqdm.write("use legacy model")
683
+ skip_special_tokens = True
684
+ for i in range(len(input_ids)):
685
+ media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
686
+ endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
687
+ input_ids[i, media_token_index - 1] = media_token_id
688
+ input_ids[i, media_token_index] = pad_token_id
689
+ input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
690
+ input_ids[i, endofmedia_token_index] = bos_token_id
691
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
692
+ image_start_index_list = [[x] for x in image_start_index_list]
693
+ image_nums = [1] * len(input_ids)
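+ # image_start_index_list marks the position right after each <|#image#|> token, where the
+ # vis_embed_size placeholder pad tokens sit; every sample carries one image, so image_nums
+ # is [1] per row.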
694
+ if "llama" in lang_encoder_name:
695
+ attention_mask[input_ids == 0] = 0
696
+ outputs = get_outputs(
697
+ model=model,
698
+ batch_images=batch_images,
699
+ attention_mask=attention_mask,
700
+ max_generation_length=30,
701
+ min_generation_length=8,
702
+ num_beams=5,
703
+ length_penalty=0,
704
+ input_ids=input_ids,
705
+ image_start_index_list=image_start_index_list,
706
+ image_nums=image_nums,
707
+ )
708
+ new_predictions = [
709
+ postprocess_captioning_generation(out).replace('"', "")
710
+ for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
711
+ ]
712
+ # if rank == 0:
713
+ # tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
714
+
715
+ for i, sample in enumerate(batch):
716
+ predictions[int(sample["image_id"])] = {
717
+ "caption": new_predictions[i],
718
+ }
719
+ results_path = (
720
+ f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
721
+ if is_flickr
722
+ else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
723
+ )
724
+ with open(results_path, "w") as f:
725
+ f.write(
726
+ json.dumps(
727
+ [
728
+ {"image_id": k, "caption": predictions[k]["caption"]}
729
+ for k in predictions
730
+ ],
731
+ indent=2,
732
+ )
733
+ )
734
+ print("save to", results_path)
735
+ del predictions
736
+ time.sleep(10)
737
+ if world_size > 1:
738
+ torch.distributed.barrier()
739
+ if rank == 0:
740
+ print(f"evaluate on rank {rank}. world size is {world_size}")
741
+ predictions = []
742
+ for rank_i in range(world_size):
743
+ part_results_path = (
744
+ f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
745
+ if is_flickr
746
+ else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
747
+ )
748
+ print("load", part_results_path)
749
+ predictions.extend(json.load(open(part_results_path)))
750
+ os.remove(part_results_path)
751
+ print("num:", len(predictions))
752
+ results_path = (
753
+ f"flickrresults_{lang_encoder_name}.json"
754
+ if is_flickr
755
+ else f"cocoresults_{lang_encoder_name}.json"
756
+ )
757
+ json.dump(predictions, open(results_path, "w"), indent=2)
758
+
759
+ metrics = compute_cider(
760
+ result_path=results_path,
761
+ annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
762
+ )
763
+ os.makedirs("eval_results", exist_ok=True)
764
+ acc = metrics["CIDEr"]
765
+ with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
766
+ f.write(json.dumps(predictions, indent=2))
767
+
768
+ # delete the temporary file
769
+ os.remove(results_path)
770
+ else:
771
+ metrics = {}
772
+ metrics["CIDEr"] = 0.0
773
+
774
+ return metrics["CIDEr"]
775
+
776
+
777
+ def evaluate_vqa(
778
+ model,
779
+ tokenizer,
780
+ image_processor,
781
+ batch_size,
782
+ image_dir_path=None,
783
+ questions_json_path=None,
784
+ annotations_json_path=None,
785
+ vqa_dataset="vqa",
786
+ vis_embed_size=None,
787
+ rank=0,
788
+ world_size=1,
789
+ id=0,
790
+ ):
791
+ """
792
+ Evaluate a model on VQA datasets. Currently supports VQA v2.0.
793
+
794
+ Args:
795
+ model (nn.Module): model to evaluate
796
+ tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
797
+ image_processor : image processor for the model
798
+ batch_size (int): batch size
799
+ image_dir_path (str): path to image directory
800
+ questions_json_path (str): path to questions json file
801
+ annotations_json_path (str): path to annotations json file
802
+ seed (int, optional): random seed. Defaults to 42.
803
+ max_generation_length (int, optional): max generation length. Defaults to 5.
804
+ num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
805
+ length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
806
+ num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
807
+ query_set_size (int, optional): size of the query set. Defaults to 2048.
808
+ num_shots (int, optional): number of shots to use. Defaults to 8.
809
+ device (int, optional): device to use. Defaults to -1 (cpu).
810
+ num_workers (int, optional): number of workers to use. Defaults to 4.
811
+ vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
812
+ Returns:
813
+ float: accuracy score
814
+ """
815
+ if world_size > 1:
816
+ torch.distributed.barrier()
817
+ if vqa_dataset == "gqa":
818
+ eval_dataset = GQADataset()
819
+ else:
820
+ eval_dataset = VQADataset(
821
+ image_dir_path=image_dir_path,
822
+ question_path=questions_json_path,
823
+ annotations_path=annotations_json_path,
824
+ vqa_dataset=vqa_dataset,
825
+ )
826
+ postprocessor = OKVQAPostProcess()
827
+ try:
828
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
829
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
830
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
831
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
832
+ except:
833
+ pass
834
+ def get_prompt(sample):
835
+ return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
836
+ # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
837
+
838
+ model.eval().cuda()
839
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
840
+ if "peft" in lang_encoder_name:
841
+ lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
842
+ predictions = []
843
+ tokenizer.padding_side = "left"
844
+ if world_size > 1:
845
+ torch.distributed.barrier()
846
+ for ii, batch in enumerate(more_itertools.chunked(
847
+ tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
848
+ )):
849
+ if ii % world_size != rank:
850
+ continue
851
+ batch_images = prepare_batch_images(
852
+ batch=batch,
853
+ image_processor=image_processor,
854
+ ).cuda()
855
+ batch_text = [get_prompt(s) for s in batch]
856
+ encodings = tokenizer(
857
+ batch_text,
858
+ return_tensors="pt",
859
+ padding="longest",
860
+ truncation=True,
861
+ max_length=2000,
862
+ )
863
+ input_ids = encodings["input_ids"].cuda()
864
+ attention_mask = encodings["attention_mask"].cuda()
865
+ skip_special_tokens = True
866
+ if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
867
+ if rank == 0:
868
+ tqdm.write("use legacy model")
869
+ for i in range(len(input_ids)):
870
+ media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
871
+ endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
872
+ input_ids[i, media_token_index - 1] = media_token_id
873
+ input_ids[i, media_token_index] = pad_token_id
874
+ input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
875
+ input_ids[i, endofmedia_token_index] = bos_token_id
876
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
877
+ image_start_index_list = [[x] for x in image_start_index_list]
878
+ image_nums = [1] * len(input_ids)
879
+ if "llama" in lang_encoder_name:
880
+ attention_mask[input_ids == 0] = 0
881
+ outputs = get_outputs(
882
+ model=model,
883
+ batch_images=batch_images,
884
+ attention_mask=attention_mask,
885
+ max_generation_length=10,
886
+ min_generation_length=1,
887
+ num_beams=5,
888
+ length_penalty=0,
889
+ input_ids=input_ids,
890
+ image_start_index_list=image_start_index_list,
891
+ image_nums=image_nums,
892
+ )
893
+ # postprocess begin
894
+ new_predictions = [
895
+ out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
896
+ ]
897
+ if vqa_dataset == "ok_vqa":
898
+ new_predictions = postprocessor._lemmatize(new_predictions)
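+ # In "special" mode, snap each free-form generation to a ground-truth answer it contains
+ # (most common answers first), with extra handling mapping "cant"/"can" to "no"/"yes".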
899
+ if model.special:
900
+ for i in range(len(new_predictions)):
901
+ for answer, _ in Counter(batch[i]['answers']).most_common():
902
+ if answer in new_predictions[i]:
903
+ new_predictions[i] = answer
904
+ break
905
+ if "cant" in new_predictions[i] and "no" == answer:
906
+ new_predictions[i] = answer
907
+ break
908
+ if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
909
+ new_predictions[i] = answer
910
+ break
911
+
912
+ # if rank == 0:
913
+ # tqdm.write(f"{image_nums} {image_start_index_list}")
914
+ # for i in range(1):
915
+ # tqdm.write(f"ID: {batch[i]['question_id']} | gt QA: {batch[i]['question']} {Counter(batch[i]['answers']).most_common()}")
916
+ # tqdm.write("prompt: " + tokenizer.decode(input_ids[i]))
917
+ # tqdm.write("model output: " + new_predictions[i])
918
+
919
+ predictions.extend(
920
+ [
921
+ {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
922
+ for p, sample in zip(new_predictions, batch)
923
+ ]
924
+ )
925
+ with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
926
+ f.write(json.dumps(predictions))
927
+ print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
928
+
929
+ time.sleep(10)
930
+ if world_size > 1:
931
+ torch.distributed.barrier()
932
+ if rank == 0:
933
+ print(f"evaluate on rank {rank}. world size is {world_size}")
934
+ predictions = []
935
+ for rank_i in range(world_size):
936
+ print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
937
+ predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
938
+ os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
939
+ print("num:", len(predictions))
940
+ # save the predictions to a temporary file
941
+ random_uuid = str(uuid.uuid4())
942
+ with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
943
+ f.write(json.dumps(predictions, indent=4))
944
+
945
+ if vqa_dataset == "gqa":
946
+ acc = compute_gqa_accuracy(predictions)
947
+ else:
948
+ acc = compute_vqa_accuracy(
949
+ f"{vqa_dataset}results_{random_uuid}.json",
950
+ questions_json_path,
951
+ annotations_json_path,
952
+ vqa_dataset=vqa_dataset,
953
+ )
954
+ print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
955
+ os.makedirs("eval_results", exist_ok=True)
956
+ with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
957
+ f.write(json.dumps(predictions, indent=2))
958
+
959
+ # delete the temporary file
960
+ os.remove(f"{vqa_dataset}results_{random_uuid}.json")
961
+ else:
962
+ time.sleep(5)
963
+ acc = 0.0
964
+ if world_size > 1:
965
+ torch.distributed.barrier()
966
+ return acc
967
+
968
+
969
+ def evaluate_refcoco(
970
+ model,
971
+ tokenizer,
972
+ image_processor,
973
+ batch_size,
974
+ tsvfile,
975
+ max_generation_length=20,
976
+ num_beams=3,
977
+ length_penalty=-2.0,
978
+ device=-1,
979
+ vis_embed_size=None,
980
+ rank=0,
981
+ world_size=1,
982
+ id=0,
983
+ ):
984
+ model.eval().cuda()
985
+ loc_token_ids = []
986
+ for i in range(1000):
987
+ loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
988
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
989
+ total = 0
990
+ correct = 0
991
+ ious = []
992
+ if "refcocog" in tsvfile:
993
+ dataset_name = "refcocog"
994
+ elif "refcocoplus" in tsvfile:
995
+ dataset_name = "refcocoplus"
996
+ else:
997
+ dataset_name = "refcoco"
998
+ with open(tsvfile, "r") as f:
999
+ lines = f.readlines()
1000
+ pbar = tqdm(lines, disable=(rank != 0))
1001
+ for ii, line in enumerate(pbar):
1002
+ if ii % world_size != rank:
1003
+ continue
1004
+ total += 1
1005
+ line = line.rstrip()
1006
+ uniq_id, image_id, text, region_coord, image = line.split("\t")
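+ # Each tsv row holds: uniq_id, image_id, referring expression, comma-separated ground-truth
+ # box, and a base64-encoded image.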
1007
+
1008
+ # image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
1009
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
1010
+ # image2 = Image.open("yolo.png").convert("RGB")
1011
+ # image1 = image1.resize((224, 224))
1012
+ # image2 = image2.resize((224, 224))
1013
+ # images = [image1, image2]
1014
+
1015
+ # gt_box = np.array(list(map(float, region_coord.split(","))))
1016
+ # width = image.width
1017
+ # height = image.height
1018
+ # gt_box /= np.array([width, height, width, height])
1019
+ # batch_images = [image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0) for image in images]
1020
+ # batch_images = torch.cat(batch_images, dim=0)
1021
+ # image = Image.open("yolo_test.png").convert("RGB")
1022
+ image = Image.open("example.png").convert("RGB")
1023
+ image = image.resize((224, 224))
1024
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1025
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text.rstrip('.')}<|#visual#|>"]
1026
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#endofattr#|>man<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|> is sitting on<|#object#|><|#previsual#|>"]
1027
+ # prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|>man<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|> is sitting on<|#object#|><|#previsual#|>"]
1028
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
1029
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
1030
+
1031
+
1032
+ encodings = tokenizer(
1033
+ prompt,
1034
+ padding="longest",
1035
+ truncation=True,
1036
+ return_tensors="pt",
1037
+ max_length=2000,
1038
+ )
1039
+ input_ids = encodings["input_ids"]
1040
+ attention_mask = encodings["attention_mask"]
1041
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1042
+ image_start_index_list = [image_start_index_list]
1043
+ image_nums = [1]
1044
+ vision_x = batch_images.cuda()
1045
+ lang_x = input_ids.cuda()
1046
+ attention_mask = attention_mask.cuda()
1047
+ print(image_start_index_list, image_nums)
1048
+
1049
+ model.debug_id = 0
1050
+ # outputs = get_outputs(
1051
+ # model=model,
1052
+ # batch_images=vision_x,
1053
+ # attention_mask=attention_mask,
1054
+ # max_generation_length=20,
1055
+ # min_generation_length=8,
1056
+ # num_beams=5,
1057
+ # length_penalty=0,
1058
+ # input_ids=lang_x,
1059
+ # image_start_index_list=image_start_index_list,
1060
+ # image_nums=image_nums,
1061
+ # )
1062
+ # print(tokenizer.decode(outputs[0]))
1063
+ # exit()
1064
+
1065
+ prebox = [93, 20, 155, 172] # man
1066
+ # prebox = [32, 82, 89, 213] # dog
1067
+ # prebox = [34, 49, 166, 164] # bike
1068
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
1069
+ outputs = model(
1070
+ vision_x=vision_x,
1071
+ lang_x=lang_x,
1072
+ attention_mask=attention_mask,
1073
+ labels=None,
1074
+ image_nums=image_nums,
1075
+ image_start_index_list=image_start_index_list,
1076
+ added_bbox_list=[torch.tensor(prebox).cuda().unsqueeze(0) / 224],
1077
+ add_box=True,
1078
+ debug_mode=True,
1079
+ )
1080
+
1081
+ boxes = outputs["boxes"]
1082
+ scores = outputs["scores"]
1083
+ box = boxes[scores.argmax()]
1084
+ open_cv_image = np.array(image)
1085
+ # Convert RGB to BGR
1086
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
1087
+ open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
1088
+ open_cv_image = cv2.rectangle(open_cv_image, prebox[:2], prebox[2:], (0, 0, 255), 2)
1089
+ cv2.imwrite(f"output2.jpg", open_cv_image)
1090
+ print(box)
1091
+ print(prebox)
1092
+ exit()
1093
+
1094
+ # force_words = ["man", "table"]
1095
+ # force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
1096
+
1097
+
1098
+ # sequences, hidden_states_for_each_step = get_outputs(
1099
+ # model=model,
1100
+ # batch_images=vision_x,
1101
+ # attention_mask=attention_mask,
1102
+ # max_generation_length=20,
1103
+ # min_generation_length=8,
1104
+ # num_beams=5,
1105
+ # length_penalty=0,
1106
+ # input_ids=lang_x,
1107
+ # image_start_index_list=image_start_index_list,
1108
+ # image_nums=image_nums,
1109
+ # force_words_ids=force_words_ids,
1110
+ # )
1111
+ # sequence = sequences[0]
1112
+ # print(tokenizer.decode(sequence))
1113
+ # for i, token in enumerate(sequence):
1114
+ # if token == model.visual_token_id:
1115
+ # print(tokenizer.decode(sequence[:i+1]))
1116
+ # if hasattr(model, "debug_id"):
1117
+ # model.debug_id += 1
1118
+ # else:
1119
+ # model.debug_id = 0
1120
+ # this_lang_x = torch.hstack([lang_x[0], sequence[:i+1]]).unsqueeze(0)
1121
+ # this_attention_mask = torch.ones_like(this_lang_x).cuda()
1122
+ # with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
1123
+ # _ = model(
1124
+ # vision_x=vision_x,
1125
+ # lang_x=this_lang_x,
1126
+ # attention_mask=this_attention_mask,
1127
+ # labels=None,
1128
+ # image_nums=image_nums,
1129
+ # image_start_index_list=image_start_index_list,
1130
+ # added_bbox_list=None,
1131
+ # )
1132
+ # exit()
1133
+
1134
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1135
+ f.write(json.dumps([total, correct]))
1136
+ if world_size > 1:
1137
+ torch.distributed.barrier()
1138
+ if rank == 0:
1139
+ total = 0
1140
+ correct = 0
1141
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1142
+ for rank_i in range(world_size):
1143
+ [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1144
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1145
+ total += total_part
1146
+ correct += correct_part
1147
+ score = correct / total
1148
+ print("score:", score)
1149
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
1150
+ pass
1151
+ else:
1152
+ score = 0.0
1153
+ if world_size > 1:
1154
+ torch.distributed.barrier()
1155
+ return score
1156
+
1157
+
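# The block above follows the per-rank aggregation pattern used throughout these eval
# scripts: each rank dumps a small [total, correct] JSON, then rank 0 merges them after
# a barrier. A minimal standalone sketch of that pattern (file names and world_size are
# placeholders, not values from this script):
import json, os

def _merge_rank_results_sketch(dataset_name, run_id, world_size):
    # each rank is assumed to have written f"{dataset_name}_results_part{rank}_{run_id}.json"
    total, correct = 0, 0
    for rank_i in range(world_size):
        path = f"{dataset_name}_results_part{rank_i}_{run_id}.json"
        part_total, part_correct = json.load(open(path))
        os.remove(path)
        total += part_total
        correct += part_correct
    return correct / total if total > 0 else 0.0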
1158
+ if __name__ == "__main__":
1159
+ main()
multimodal/build/lib/open_flamingo/eval/evaluate_find_showcase.py ADDED
@@ -0,0 +1,1700 @@
1
+ import argparse
2
+ import json
3
+ from math import ceil
4
+ import os
5
+ import random
6
+ import uuid
7
+ from collections import defaultdict
8
+ from typing import Callable
9
+ import time
10
+ import cv2
11
+ import webdataset as wds
12
+ from sklearn.metrics import recall_score, average_precision_score
13
+
14
+ import more_itertools
15
+ import numpy as np
16
+ import torch
17
+ from coco_metric import compute_cider, postprocess_captioning_generation
18
+ from eval_datasets import VQADataset
19
+ from tqdm import tqdm
20
+ from collections import Counter
21
+
22
+ from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
23
+ from open_flamingo.eval.classification import (
24
+ compute_per_sample_probs,
25
+ compute_per_sample_loss,
26
+ )
27
+ from open_flamingo.eval.imagenet_utils import (
28
+ openai_imagenet_classnames,
29
+ IMAGENET_1K_CLASS_ID_TO_LABEL,
30
+ )
31
+
32
+ from open_flamingo.src.factory import create_model_and_transforms
33
+ from PIL import Image
34
+ from io import BytesIO
35
+ import base64
36
+ from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
37
+ import string
38
+ from lavis.datasets.builders import load_dataset
39
+ from open_flamingo.eval.task.reg import evaluate_reg
40
+ from open_flamingo.eval.task.gqa import GQADataset
41
+ from open_flamingo.eval.task.vl_checklist import evaluate_vlc
42
+ from open_flamingo.eval.task.crepe import evaluate_crepe
+ import logging  # used by OKVQAPostProcess.lemmatizer when spaCy is missing
43
+
44
+ def get_iou(box1, box2):
45
+ # box1 and box2 should be in the format [x1, y1, x2, y2]
46
+ intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
47
+ max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
48
+ area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
49
+ area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
50
+ union = area_box1 + area_box2 - intersection
51
+ iou = intersection / union if union > 0 else 0
52
+ return iou
53
+
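# Illustrative sanity check for get_iou (a sketch; boxes are [x1, y1, x2, y2] in pixels):
def _iou_sanity_check():
    # two 100x100 boxes offset by 50 px: intersection 2500, union 17500 -> IoU ~= 0.143
    assert abs(get_iou([0, 0, 100, 100], [50, 50, 150, 150]) - 2500 / 17500) < 1e-6
    # disjoint boxes have zero IoU
    assert get_iou([0, 0, 10, 10], [20, 20, 30, 30]) == 0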
54
+ def expand2square(pil_img, background_color):
55
+ width, height = pil_img.size
56
+ if width == height:
57
+ return pil_img
58
+ elif width > height:
59
+ result = Image.new(pil_img.mode, (width, width), background_color)
60
+ result.paste(pil_img, (0, (width - height) // 2))
61
+ return result
62
+ else:
63
+ result = Image.new(pil_img.mode, (height, height), background_color)
64
+ result.paste(pil_img, ((height - width) // 2, 0))
65
+ return result
66
+
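# Illustrative use of expand2square (a sketch): pad a non-square PIL image to a square
# canvas before resizing; the sizes and fill color here are arbitrary placeholders.
def _expand2square_example():
    from PIL import Image
    img = Image.new("RGB", (200, 100), (255, 255, 255))   # wide dummy image
    squared = expand2square(img, (0, 0, 0))                # pad top/bottom with black
    assert squared.size == (200, 200)
    return squared.resize((224, 224))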
67
+ parser = argparse.ArgumentParser()
68
+ parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
69
+ parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
70
+ parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
71
+ parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
72
+ parser.add_argument("--checkpoint_path", type=str, required=True)
73
+ parser.add_argument(
74
+ "--results_file", type=str, default=None, help="JSON file to save results"
75
+ )
76
+
77
+ # Trial arguments
78
+ parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
79
+ parser.add_argument(
80
+ "--num_trials",
81
+ type=int,
82
+ default=1,
83
+ help="Number of trials to run for each shot using different demonstrations",
84
+ )
85
+ parser.add_argument(
86
+ "--trial_seeds",
87
+ nargs="+",
88
+ default=[0],
89
+ help="Seeds to use for each trial for picking demonstrations and eval sets",
90
+ )
91
+ parser.add_argument(
92
+ "--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
93
+ )
94
+
95
+ parser.add_argument("--batch_size", type=int, default=8)
96
+
97
+ # Per-dataset evaluation flags
98
+ parser.add_argument(
99
+ "--eval_coco",
100
+ action="store_true",
101
+ default=False,
102
+ help="Whether to evaluate on COCO.",
103
+ )
104
+ parser.add_argument(
105
+ "--eval_vqav2",
106
+ action="store_true",
107
+ default=False,
108
+ help="Whether to evaluate on VQAV2.",
109
+ )
110
+ parser.add_argument(
111
+ "--eval_ok_vqa",
112
+ action="store_true",
113
+ default=False,
114
+ help="Whether to evaluate on OK-VQA.",
115
+ )
116
+ parser.add_argument(
117
+ "--eval_imagenet",
118
+ action="store_true",
119
+ default=False,
120
+ help="Whether to evaluate on ImageNet.",
121
+ )
122
+
123
+ parser.add_argument(
124
+ "--eval_flickr30",
125
+ action="store_true",
126
+ default=False,
127
+ help="Whether to evaluate on Flickr30.",
128
+ )
129
+
130
+ parser.add_argument(
131
+ "--eval_refcoco",
132
+ action="store_true",
133
+ default=False,
134
+ help="Whether to evaluate on RefCOCO.",
135
+ )
136
+
137
+ # Dataset arguments
138
+
139
+ ## Flickr30 Dataset
140
+ parser.add_argument(
141
+ "--flickr_image_dir_path",
142
+ type=str,
143
+ help="Path to the flickr30/flickr30k_images directory.",
144
+ default=None,
145
+ )
146
+ parser.add_argument(
147
+ "--flickr_annotations_json_path",
148
+ type=str,
149
+ help="Path to the dataset_flickr30k_coco_style.json file.",
150
+ default=None,
151
+ )
152
+
153
+ ## COCO Dataset
154
+ parser.add_argument(
155
+ "--coco_image_dir_path",
156
+ type=str,
157
+ help="Path to the flickr30/flickr30k_images directory.",
158
+ default=None,
159
+ )
160
+ parser.add_argument(
161
+ "--coco_annotations_json_path",
162
+ type=str,
163
+ default=None,
164
+ )
165
+
166
+ ## VQAV2 Dataset
167
+ parser.add_argument(
168
+ "--vqav2_image_dir_path",
169
+ type=str,
170
+ default=None,
171
+ )
172
+ parser.add_argument(
173
+ "--vqav2_questions_json_path",
174
+ type=str,
175
+ default=None,
176
+ )
177
+ parser.add_argument(
178
+ "--vqav2_annotations_json_path",
179
+ type=str,
180
+ default=None,
181
+ )
182
+
183
+ ## OK-VQA Dataset
184
+ parser.add_argument(
185
+ "--ok_vqa_image_dir_path",
186
+ type=str,
187
+ help="Path to the vqav2/train2014 directory.",
188
+ default=None,
189
+ )
190
+ parser.add_argument(
191
+ "--ok_vqa_questions_json_path",
192
+ type=str,
193
+ help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
194
+ default=None,
195
+ )
196
+ parser.add_argument(
197
+ "--ok_vqa_annotations_json_path",
198
+ type=str,
199
+ help="Path to the v2_mscoco_train2014_annotations.json file.",
200
+ default=None,
201
+ )
202
+
203
+ ## Imagenet dataset
204
+ parser.add_argument("--imagenet_root", type=str, default="/tmp")
205
+
206
+ ## RefCOCO dataset
207
+ parser.add_argument("--refcoco_tsvfile", type=str, default=None)
208
+
209
+ parser.add_argument(
210
+ "--location_token_num",
211
+ default=1000,
212
+ type=int,
213
+ )
214
+ # distributed training
215
+ parser.add_argument(
216
+ "--dist-url",
217
+ default="env://",
218
+ type=str,
219
+ help="url used to set up distributed training",
220
+ )
221
+ parser.add_argument(
222
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
223
+ )
224
+ parser.add_argument(
225
+ "--horovod",
226
+ default=False,
227
+ action="store_true",
228
+ help="Use horovod for distributed training.",
229
+ )
230
+ parser.add_argument(
231
+ "--no-set-device-rank",
232
+ default=False,
233
+ action="store_true",
234
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
235
+ )
236
+ parser.add_argument(
237
+ "--dist",
238
+ default=False,
239
+ action="store_true",
240
+ )
241
+ parser.add_argument(
242
+ "--lora",
243
+ default=False,
244
+ action="store_true",
245
+ )
246
+ parser.add_argument(
247
+ "--lora_r",
248
+ default=16,
249
+ type=int,
250
+ required=False,
251
+ )
252
+ parser.add_argument(
253
+ "--legacy",
254
+ default=False,
255
+ action="store_true",
256
+ )
257
+ parser.add_argument(
258
+ "--special",
259
+ default=False,
260
+ action="store_true",
261
+ )
262
+ parser.add_argument(
263
+ "--id",
264
+ default=0,
265
+ type=int,
266
+ required=False,
267
+ )
268
+
269
+ parser.add_argument(
270
+ "--eval_gqa",
271
+ default=False,
272
+ action="store_true",
273
+ )
274
+ parser.add_argument(
275
+ "--use_sam",
276
+ default=None,
277
+ type=str,
278
+ required=False,
279
+ )
280
+ parser.add_argument(
281
+ "--add_visual_token",
282
+ default=False,
283
+ action="store_true",
284
+ )
285
+ parser.add_argument(
286
+ "--use_format_v2",
287
+ default=False,
288
+ action="store_true",
289
+ )
290
+ parser.add_argument(
291
+ "--eval_aro",
292
+ default=False,
293
+ action="store_true",
294
+ )
295
+ parser.add_argument(
296
+ "--eval_pisc",
297
+ default=False,
298
+ action="store_true",
299
+ )
300
+ parser.add_argument(
301
+ "--eval_reg",
302
+ default=False,
303
+ action="store_true",
304
+ )
305
+ parser.add_argument(
306
+ "--eval_vlc",
307
+ default=False,
308
+ action="store_true",
309
+ )
310
+ parser.add_argument(
311
+ "--eval_crepe",
312
+ default=False,
313
+ action="store_true",
314
+ )
315
+ parser.add_argument(
316
+ "--level",
317
+ default=4,
318
+ type=int,
319
+ )
320
+ parser.add_argument(
321
+ "--type",
322
+ default="swap",
323
+ type=str,
324
+ )
325
+
326
+
327
+ class OKVQAPostProcess():
328
+ def __init__(self):
329
+ self._lemmatizer = None
330
+
331
+ def _lemmatize(self, answers):
332
+ def apply(answer):
333
+ doc = self.lemmatizer(answer)
334
+
335
+ words = []
336
+ for token in doc:
337
+ if token.pos_ in ["NOUN", "VERB"]:
338
+ words.append(token.lemma_)
339
+ else:
340
+ words.append(token.text)
341
+ answer = " ".join(words)
342
+
343
+ return answer
344
+
345
+ return [apply(answer) for answer in answers]
346
+
347
+ @property
348
+ def lemmatizer(self):
349
+ if self._lemmatizer is None:
350
+ try:
351
+ import spacy
352
+
353
+ self._lemmatizer = spacy.load("en_core_web_sm")
354
+ except ImportError:
355
+ logging.error(
356
+ """
357
+ Please install spacy and en_core_web_sm model to apply lemmatization.
358
+ python -m spacy download en_core_web_sm
359
+ OR
360
+ import spacy.cli
361
+ spacy.cli.download("en_core_web_sm")
362
+ """
363
+ )
364
+ exit(1)
365
+
366
+ return self._lemmatizer
367
+
368
+
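# Sketch of how OKVQAPostProcess is used below (requires spaCy with en_core_web_sm;
# the answer strings here are made-up examples, not dataset outputs):
def _okvqa_lemmatize_example():
    postprocessor = OKVQAPostProcess()
    # nouns/verbs are lemmatized, other tokens kept verbatim,
    # e.g. "riding horses" is expected to become "ride horse"
    return postprocessor._lemmatize(["riding horses", "two dogs"])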
369
+ def main():
370
+ args = parser.parse_args()
371
+ if args.dist:
372
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
373
+ print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
374
+ device_id = init_distributed_device(args)
375
+ else:
376
+ args.rank = 0
377
+ args.world_size = 1
+ args.device = 0  # the evaluate_* helpers below read args.device; init_distributed_device sets it only in the dist branch
378
+ print(f"rank: {args.rank} world_size: {args.world_size}")
379
+
380
+ if "sam" in args.checkpoint_path:
381
+ args.use_sam = "vit_l"
382
+
383
+ args.add_visual_token = True
384
+ if "lora" in args.checkpoint_path:
385
+ args.lora = True
386
+
387
+
388
+ args.add_pe = False
389
+ args.add_box = True
390
+ args.relation = False
391
+ args.enhance_data = False
392
+ args.use_format_v2 = True
393
+
394
+
395
+
396
+ import hashlib
397
+ args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest()
398
+
399
+ # load model
400
+ flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
401
+ args.vision_encoder_path,
402
+ args.vision_encoder_pretrained,
403
+ args.lm_path,
404
+ args.lm_tokenizer_path,
405
+ location_token_num=args.location_token_num,
406
+ lora=args.lora,
407
+ lora_r=16,
408
+ use_sam=args.use_sam,
409
+ add_visual_token=args.add_visual_token,
410
+ use_format_v2=args.use_format_v2,
411
+ add_box=args.add_box,
412
+ add_pe=args.add_pe,
413
+ add_relation=args.relation,
414
+ enhance_data=args.enhance_data,
415
+ )
416
+ flamingo.use_format_v2 = args.use_format_v2
417
+ if args.special:
418
+ flamingo.special = True
419
+ else:
420
+ flamingo.special = False
421
+ if args.legacy:
422
+ flamingo.legacy = True
423
+ print("use legacy evaluation")
424
+ flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
425
+ flamingo.expr_name = args.checkpoint_path.split("/")[-2]
426
+ if args.rank == 0:
427
+ print("legacy", True if hasattr(flamingo, "legacy") else False)
428
+ print("step:", flamingo.step_num)
429
+ print("expr:", flamingo.expr_name)
430
+ print("use format v2:", flamingo.use_format_v2)
431
+ print(args)
432
+ checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
433
+ model_state_dict = {}
434
+ for key in checkpoint["model_state_dict"].keys():
435
+ model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
436
+ if "vision_encoder.logit_scale" in model_state_dict:
437
+ # previous checkpoint has some unnecessary weights
438
+ del model_state_dict["vision_encoder.logit_scale"]
439
+ del model_state_dict["vision_encoder.visual.proj"]
440
+ del model_state_dict["vision_encoder.visual.ln_post.weight"]
441
+ del model_state_dict["vision_encoder.visual.ln_post.bias"]
442
+ flamingo.load_state_dict(model_state_dict, strict=True)
443
+ results = defaultdict(list)
444
+ if args.eval_coco:
445
+ print("Evaluating on COCO...")
446
+ for shot in args.shots:
447
+ scores = []
448
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
449
+ cider_score = evaluate_coco_flickr(
450
+ model=flamingo,
451
+ tokenizer=tokenizer,
452
+ image_processor=image_processor,
453
+ batch_size=args.batch_size,
454
+ image_dir_path=args.coco_image_dir_path,
455
+ annotations_json_path=args.coco_annotations_json_path,
456
+ device=args.device,
457
+ seed=seed,
458
+ vis_embed_size=vis_embed_size,
459
+ rank=args.rank,
460
+ world_size=args.world_size,
461
+ id=args.id,
462
+ )
463
+ print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
464
+ scores.append(cider_score)
465
+ print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
466
+ results["coco"].append(
467
+ {"shots": shot, "trials": scores, "mean": np.mean(scores)}
468
+ )
469
+
470
+ if args.eval_ok_vqa:
471
+ print("Evaluating on OK-VQA...")
472
+ for shot in args.shots:
473
+ scores = []
474
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
475
+ ok_vqa_score = evaluate_vqa(
476
+ model=flamingo,
477
+ tokenizer=tokenizer,
478
+ image_processor=image_processor,
479
+ batch_size=args.batch_size,
480
+ image_dir_path=args.ok_vqa_image_dir_path,
481
+ questions_json_path=args.ok_vqa_questions_json_path,
482
+ annotations_json_path=args.ok_vqa_annotations_json_path,
483
+ vqa_dataset="ok_vqa",
484
+ vis_embed_size=vis_embed_size,
485
+ rank=args.rank,
486
+ world_size=args.world_size,
487
+ id=args.id,
488
+ )
489
+ results["ok_vqa"].append(
490
+ {"shots": shot, "score": ok_vqa_score}
491
+ )
492
+
493
+ if args.eval_vqav2:
494
+ print("Evaluating on VQAv2...")
495
+ for shot in args.shots:
496
+ scores = []
497
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
498
+ vqa_score = evaluate_vqa(
499
+ model=flamingo,
500
+ tokenizer=tokenizer,
501
+ image_processor=image_processor,
502
+ batch_size=args.batch_size,
503
+ image_dir_path=args.vqav2_image_dir_path,
504
+ questions_json_path=args.vqav2_questions_json_path,
505
+ annotations_json_path=args.vqav2_annotations_json_path,
506
+ vqa_dataset="vqa",
507
+ vis_embed_size=vis_embed_size,
508
+ rank=args.rank,
509
+ world_size=args.world_size,
510
+ id=args.id,
511
+ )
512
+ results["vqav2"].append(
513
+ {"shots": shot, "score": vqa_score}
514
+ )
515
+
516
+ if args.eval_gqa:
517
+ print("Evaluating on GQA...")
518
+ for shot in args.shots:
519
+ scores = []
520
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
521
+ vqa_score = evaluate_vqa(
522
+ model=flamingo,
523
+ tokenizer=tokenizer,
524
+ image_processor=image_processor,
525
+ batch_size=args.batch_size,
526
+ vqa_dataset="gqa",
527
+ vis_embed_size=vis_embed_size,
528
+ rank=args.rank,
529
+ world_size=args.world_size,
530
+ id=args.id,
531
+ )
532
+ results["gqa"].append(
533
+ {"shots": shot, "score": vqa_score}
534
+ )
535
+
536
+ if args.eval_refcoco:
537
+ print("Evaluating on RefCOCO...")
538
+ refcoco_score = evaluate_refcoco(
539
+ model=flamingo,
540
+ tokenizer=tokenizer,
541
+ image_processor=image_processor,
542
+ batch_size=args.batch_size,
543
+ device=args.device,
544
+ tsvfile=args.refcoco_tsvfile,
545
+ vis_embed_size=vis_embed_size,
546
+ rank=args.rank,
547
+ world_size=args.world_size,
548
+ id=args.id,
549
+ )
550
+ results["refcoco"].append(
551
+ {"score": refcoco_score}
552
+ )
553
+ if args.eval_aro:
554
+ print("Evaluating on ARO...")
555
+ aro_score = evaluate_aro(
556
+ model=flamingo,
557
+ tokenizer=tokenizer,
558
+ image_processor=image_processor,
559
+ batch_size=args.batch_size,
560
+ device=args.device,
561
+ tsvfile=args.refcoco_tsvfile,
562
+ vis_embed_size=vis_embed_size,
563
+ rank=args.rank,
564
+ world_size=args.world_size,
565
+ id=args.id,
566
+ add_relation=args.relation,
567
+ )
568
+ results["aro"].append(
569
+ {"score": aro_score}
570
+ )
571
+ if args.eval_pisc:
572
+ print("Evaluating on PISC...")
573
+ pisc_score = evaluate_pisc(
574
+ model=flamingo,
575
+ tokenizer=tokenizer,
576
+ image_processor=image_processor,
577
+ batch_size=args.batch_size,
578
+ device=args.device,
579
+ tsvfile=args.refcoco_tsvfile,
580
+ vis_embed_size=vis_embed_size,
581
+ rank=args.rank,
582
+ world_size=args.world_size,
583
+ id=args.id,
584
+ )
585
+ results["pisc"].append(
586
+ {"score": pisc_score}
587
+ )
588
+ if args.eval_reg:
589
+ print("Evaluating on Referring Expression Generation...")
590
+ cider = evaluate_reg(
591
+ model=flamingo,
592
+ tokenizer=tokenizer,
593
+ image_processor=image_processor,
594
+ vis_embed_size=vis_embed_size,
595
+ rank=args.rank,
596
+ world_size=args.world_size,
597
+ id=args.id,
598
+ )
599
+ results["reg"].append(
600
+ {"score": cider}
601
+ )
602
+ if args.eval_vlc:
603
+ print("Evaluating on VL-checklist...")
604
+ vlc_score = evaluate_vlc(
605
+ model=flamingo,
606
+ tokenizer=tokenizer,
607
+ image_processor=image_processor,
608
+ vis_embed_size=vis_embed_size,
609
+ rank=args.rank,
610
+ world_size=args.world_size,
611
+ id=args.id,
612
+ )
613
+ results["vlc"].append(
614
+ {"score": vlc_score}
615
+ )
616
+ if args.eval_crepe:
617
+ print("Evaluating on CREPE...")
618
+ crepe_score = evaluate_crepe(
619
+ model=flamingo,
620
+ tokenizer=tokenizer,
621
+ image_processor=image_processor,
622
+ vis_embed_size=vis_embed_size,
623
+ rank=args.rank,
624
+ world_size=args.world_size,
625
+ id=args.id,
626
+ level=args.level,
627
+ type=args.type,
628
+ )
629
+ results["crepe"].append(
630
+ {"score": crepe_score}
631
+ )
632
+
633
+ def prepare_batch_images(batch, image_processor):
634
+ batch_images = None
635
+ for b in batch:
636
+ b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
637
+ if batch_images is None:
638
+ batch_images = b_image
639
+ else:
640
+ batch_images = torch.cat([batch_images, b_image], dim=0)
641
+ return batch_images
642
+
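# For reference, prepare_batch_images stacks one frame per sample into the
# (batch, num_media, num_frames, C, H, W) layout the model expects. A minimal sketch
# with dummy PIL images, assuming the CLIP ViT-L/14 224x224 preprocessing used in this file:
def _prepare_batch_images_example(image_processor):
    from PIL import Image
    batch = [{"image": Image.new("RGB", (224, 224))} for _ in range(4)]
    batch_images = prepare_batch_images(batch=batch, image_processor=image_processor)
    assert batch_images.shape == (4, 1, 1, 3, 224, 224)
    return batch_images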
643
+ def get_outputs(
644
+ model,
645
+ batch_images,
646
+ attention_mask,
647
+ max_generation_length,
648
+ min_generation_length,
649
+ num_beams,
650
+ length_penalty,
651
+ input_ids,
652
+ image_start_index_list=None,
653
+ image_nums=None,
654
+ bad_words_ids=None,
655
+ ):
656
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
657
+ outputs = model.generate(
658
+ batch_images,
659
+ input_ids,
660
+ attention_mask=attention_mask,
661
+ max_new_tokens=max_generation_length,
662
+ min_length=min_generation_length,
663
+ num_beams=num_beams,
664
+ length_penalty=length_penalty,
665
+ image_start_index_list=image_start_index_list,
666
+ image_nums=image_nums,
667
+ bad_words_ids=bad_words_ids,
668
+ )
669
+
670
+ outputs = outputs[:, len(input_ids[0]) :]
671
+ return outputs
672
+
673
+
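# Sketch of calling get_outputs for a single prompt: it wraps model.generate under fp16
# autocast and strips the prompt tokens from the returned sequence. The prompt string and
# generation settings below are illustrative and mirror the ones built elsewhere in this file.
def _get_outputs_example(model, tokenizer, batch_images, vis_embed_size, media_token_id):
    prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>"]
    encodings = tokenizer(prompt, return_tensors="pt")
    input_ids = encodings["input_ids"].cuda()
    attention_mask = encodings["attention_mask"].cuda()
    image_start_index_list = [[(input_ids[0] == media_token_id).nonzero()[-1].item() + 1]]
    outputs = get_outputs(
        model=model,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=20,
        min_generation_length=1,
        num_beams=3,
        length_penalty=0,
        input_ids=input_ids,
        image_start_index_list=image_start_index_list,
        image_nums=[1],
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)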
674
+ def evaluate_coco_flickr(
675
+ model,
676
+ tokenizer,
677
+ image_processor,
678
+ batch_size,
679
+ image_dir_path,
680
+ annotations_json_path,
681
+ seed=42,
682
+ max_generation_length=20,
683
+ num_beams=1,
684
+ length_penalty=-2.0,
685
+ device=-1,
686
+ is_flickr=False,
687
+ vis_embed_size=None,
688
+ rank=0,
689
+ world_size=1,
690
+ id=0,
691
+ ):
692
+ """Evaluate a model on COCO dataset.
693
+
694
+ Args:
695
+ model (nn.Module): model to evaluate
696
+ tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
697
+ image_processor : image processor for the model
698
+ batch_size (int): batch size
699
+ image_dir_path (str, optional): path to the directory containing the images.
700
+ annotations_json_path (str, optional): path to the json file containing the annotations.
701
+ seed (int, optional): seed for random number generator. Defaults to 42.
702
+ max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20.
703
+ num_beams (int, optional): number of beams to use for beam search. Defaults to 1.
704
+ length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
705
+ num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
706
+ query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
707
+ num_shots (int, optional): number of in-context samples to use. Defaults to 8.
708
+ device (int, optional): device to use. Defaults to -1.
709
+ num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
710
+ is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
711
+
712
+ Returns:
713
+ float: CIDEr score
714
+
715
+ """
716
+ # eval_dataset = COCOFlickrDataset(
717
+ # image_dir_path=image_dir_path,
718
+ # annotations_path=annotations_json_path,
719
+ # is_flickr=is_flickr,
720
+ # )
721
+ coco_dataset = load_dataset("coco_caption")
722
+ eval_dataset = coco_dataset["test"]
723
+
724
+
725
+ model.eval().cuda()
726
+ predictions = defaultdict()
727
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
728
+ # if "peft" in lang_encoder_name:
729
+ # lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
730
+ try:
731
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
732
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
733
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
734
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
735
+ except:
736
+ pass
737
+
738
+ def get_prompt(sample):
739
+ return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
740
+
741
+ tokenizer.padding_side = "left"
742
+ cnt = 0
743
+ if world_size > 1:
744
+ torch.distributed.barrier()
745
+ desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
746
+ for ii, batch in enumerate(more_itertools.chunked(
747
+ tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
748
+ )):
749
+ if ii % world_size != rank:
750
+ continue
751
+ cnt += len(batch)
752
+ batch_images = prepare_batch_images(
753
+ batch=batch,
754
+ image_processor=image_processor,
755
+ ).cuda()
756
+ batch_text = [get_prompt(s) for s in batch]
757
+ encodings = tokenizer(
758
+ batch_text,
759
+ padding="longest",
760
+ truncation=True,
761
+ return_tensors="pt",
762
+ max_length=2000,
763
+ )
764
+ input_ids = encodings["input_ids"].cuda()
765
+ attention_mask = encodings["attention_mask"].cuda()
766
+ skip_special_tokens = False
767
+ if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
768
+ if rank == 0:
769
+ tqdm.write("use legacy model")
770
+ skip_special_tokens = True
771
+ for i in range(len(input_ids)):
772
+ media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
773
+ endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
774
+ input_ids[i, media_token_index - 1] = media_token_id
775
+ input_ids[i, media_token_index] = pad_token_id
776
+ input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
777
+ input_ids[i, endofmedia_token_index] = bos_token_id
778
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
779
+ image_start_index_list = [[x] for x in image_start_index_list]
780
+ image_nums = [1] * len(input_ids)
781
+ if "llama" in lang_encoder_name:
782
+ attention_mask[input_ids == 0] = 0
783
+ outputs = get_outputs(
784
+ model=model,
785
+ batch_images=batch_images,
786
+ attention_mask=attention_mask,
787
+ max_generation_length=30,
788
+ min_generation_length=8,
789
+ num_beams=5,
790
+ length_penalty=0,
791
+ input_ids=input_ids,
792
+ image_start_index_list=image_start_index_list,
793
+ image_nums=image_nums,
794
+ )
795
+ new_predictions = [
796
+ postprocess_captioning_generation(out).replace('"', "")
797
+ for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
798
+ ]
799
+ # if rank == 0:
800
+ # tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
801
+
802
+ for i, sample in enumerate(batch):
803
+ predictions[int(sample["image_id"])] = {
804
+ "caption": new_predictions[i],
805
+ }
806
+ results_path = (
807
+ f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
808
+ if is_flickr
809
+ else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
810
+ )
811
+ with open(results_path, "w") as f:
812
+ f.write(
813
+ json.dumps(
814
+ [
815
+ {"image_id": k, "caption": predictions[k]["caption"]}
816
+ for k in predictions
817
+ ],
818
+ indent=2,
819
+ )
820
+ )
821
+ print("save to", results_path)
822
+ del predictions
823
+ time.sleep(10)
824
+ if world_size > 1:
825
+ torch.distributed.barrier()
826
+ if rank == 0:
827
+ print(f"evaluate on rank {rank}. world size is {world_size}")
828
+ predictions = []
829
+ for rank_i in range(world_size):
830
+ part_results_path = (
831
+ f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
832
+ if is_flickr
833
+ else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
834
+ )
835
+ print("load", part_results_path)
836
+ predictions.extend(json.load(open(part_results_path)))
837
+ os.remove(part_results_path)
838
+ print("num:", len(predictions))
839
+ results_path = (
840
+ f"flickrresults_{lang_encoder_name}.json"
841
+ if is_flickr
842
+ else f"cocoresults_{lang_encoder_name}.json"
843
+ )
844
+ json.dump(predictions, open(results_path, "w"), indent=2)
845
+
846
+ metrics = compute_cider(
847
+ result_path=results_path,
848
+ annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
849
+ )
850
+ os.makedirs("eval_results", exist_ok=True)
851
+ acc = metrics["CIDEr"]
852
+ with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
853
+ f.write(json.dumps(predictions, indent=2))
854
+
855
+ # delete the temporary file
856
+ os.remove(results_path)
857
+ else:
858
+ metrics = {}
859
+ metrics["CIDEr"] = 0.0
860
+
861
+ return metrics["CIDEr"]
862
+
863
+
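# The merged predictions file handed to compute_cider above is just a list of
# {"image_id", "caption"} records. A minimal sketch of producing one (the output path
# is a placeholder):
def _write_caption_results_sketch(predictions, path="cocoresults_example.json"):
    # predictions: {image_id -> {"caption": str}} as built in evaluate_coco_flickr
    records = [{"image_id": k, "caption": v["caption"]} for k, v in predictions.items()]
    with open(path, "w") as f:
        json.dump(records, f, indent=2)
    return path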
864
+ def evaluate_vqa(
865
+ model,
866
+ tokenizer,
867
+ image_processor,
868
+ batch_size,
869
+ image_dir_path=None,
870
+ questions_json_path=None,
871
+ annotations_json_path=None,
872
+ vqa_dataset="vqa",
873
+ vis_embed_size=None,
874
+ rank=0,
875
+ world_size=1,
876
+ id=0,
877
+ ):
878
+ """
879
+ Evaluate a model on VQA datasets. Currently supports VQA v2.0.
880
+
881
+ Args:
882
+ model (nn.Module): model to evaluate
883
+ tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
884
+ image_processor : image processor for the model
885
+ batch_size (int): batch size
886
+ image_dir_path (str): path to image directory
887
+ questions_json_path (str): path to questions json file
888
+ annotations_json_path (str): path to annotations json file
889
+ seed (int, optional): random seed. Defaults to 42.
890
+ max_generation_length (int, optional): max generation length. Defaults to 5.
891
+ num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
892
+ length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
893
+ num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
894
+ query_set_size (int, optional): size of the query set. Defaults to 2048.
895
+ num_shots (int, optional): number of shots to use. Defaults to 8.
896
+ device (int, optional): device to use. Defaults to -1 (cpu).
897
+ num_workers (int, optional): number of workers to use. Defaults to 4.
898
+ vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
899
+ Returns:
900
+ float: accuracy score
901
+ """
902
+ if world_size > 1:
903
+ torch.distributed.barrier()
904
+ if vqa_dataset == "gqa":
905
+ eval_dataset = GQADataset()
906
+ else:
907
+ eval_dataset = VQADataset(
908
+ image_dir_path=image_dir_path,
909
+ question_path=questions_json_path,
910
+ annotations_path=annotations_json_path,
911
+ vqa_dataset=vqa_dataset,
912
+ )
913
+ postprocessor = OKVQAPostProcess()
914
+ try:
915
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
916
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
917
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
918
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
919
+ except:
920
+ pass
921
+ def get_prompt(sample):
922
+ return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
923
+ # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
924
+
925
+ model.eval().cuda()
926
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
927
+ if "peft" in lang_encoder_name:
928
+ lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
929
+ predictions = []
930
+ tokenizer.padding_side = "left"
931
+ if world_size > 1:
932
+ torch.distributed.barrier()
933
+ this_tot = 0
934
+ for ii, batch in enumerate(more_itertools.chunked(
935
+ tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
936
+ )):
937
+ if ii % world_size != rank:
938
+ continue
939
+ batch_images = prepare_batch_images(
940
+ batch=batch,
941
+ image_processor=image_processor,
942
+ ).cuda()
943
+ batch_text = [get_prompt(s) for s in batch]
944
+ encodings = tokenizer(
945
+ batch_text,
946
+ return_tensors="pt",
947
+ padding="longest",
948
+ truncation=True,
949
+ max_length=2000,
950
+ )
951
+ input_ids = encodings["input_ids"].cuda()
952
+ attention_mask = encodings["attention_mask"].cuda()
953
+ skip_special_tokens = True
954
+ if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
955
+ if rank == 0:
956
+ tqdm.write("use legacy model")
957
+ for i in range(len(input_ids)):
958
+ media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
959
+ endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
960
+ input_ids[i, media_token_index - 1] = media_token_id
961
+ input_ids[i, media_token_index] = pad_token_id
962
+ input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
963
+ input_ids[i, endofmedia_token_index] = bos_token_id
964
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
965
+ image_start_index_list = [[x] for x in image_start_index_list]
966
+ image_nums = [1] * len(input_ids)
967
+ if "llama" in lang_encoder_name:
968
+ attention_mask[input_ids == 0] = 0
969
+ outputs = get_outputs(
970
+ model=model,
971
+ batch_images=batch_images,
972
+ attention_mask=attention_mask,
973
+ max_generation_length=10,
974
+ min_generation_length=1,
975
+ num_beams=5,
976
+ length_penalty=0,
977
+ input_ids=input_ids,
978
+ image_start_index_list=image_start_index_list,
979
+ image_nums=image_nums,
980
+ )
981
+ # postprocess begin
982
+ new_predictions = [
983
+ out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
984
+ ]
985
+ if vqa_dataset == "ok_vqa":
986
+ new_predictions = postprocessor._lemmatize(new_predictions)
987
+ if model.special:
988
+ for i in range(len(new_predictions)):
989
+ for answer, _ in Counter(batch[i]['answers']).most_common():
990
+ if answer in new_predictions[i]:
991
+ new_predictions[i] = answer
992
+ break
993
+ if "cant" in new_predictions[i] and "no" == answer:
994
+ new_predictions[i] = answer
995
+ break
996
+ if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
997
+ new_predictions[i] = answer
998
+ break
999
+
1000
+ this_tot += 1
1001
+ if rank == 0 and this_tot % 20 == 0:
1002
+ for i in range(1):
1003
+ tqdm.write("model output: " + new_predictions[i])
1004
+
1005
+ predictions.extend(
1006
+ [
1007
+ {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
1008
+ for p, sample in zip(new_predictions, batch)
1009
+ ]
1010
+ )
1011
+ with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
1012
+ f.write(json.dumps(predictions))
1013
+ print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
1014
+
1015
+ time.sleep(10)
1016
+ if world_size > 1:
1017
+ torch.distributed.barrier()
1018
+ if rank == 0:
1019
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1020
+ predictions = []
1021
+ for rank_i in range(world_size):
1022
+ print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
1023
+ predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
1024
+ os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
1025
+ print("num:", len(predictions))
1026
+ # save the predictions to a temporary file
1027
+ random_uuid = str(uuid.uuid4())
1028
+ with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
1029
+ f.write(json.dumps(predictions, indent=4))
1030
+
1031
+ if vqa_dataset == "gqa":
1032
+ acc = compute_gqa_accuracy(predictions)
1033
+ else:
1034
+ acc = compute_vqa_accuracy(
1035
+ f"{vqa_dataset}results_{random_uuid}.json",
1036
+ questions_json_path,
1037
+ annotations_json_path,
1038
+ vqa_dataset=vqa_dataset,
1039
+ )
1040
+ print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
1041
+ os.makedirs("eval_results", exist_ok=True)
1042
+ with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
1043
+ f.write(json.dumps(predictions, indent=2))
1044
+
1045
+ # delete the temporary file
1046
+ os.remove(f"{vqa_dataset}results_{random_uuid}.json")
1047
+ else:
1048
+ time.sleep(5)
1049
+ acc = 0.0
1050
+ if world_size > 1:
1051
+ torch.distributed.barrier()
1052
+ return acc
1053
+
1054
+
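# For reference, the per-rank VQA records written above only need "answer" and
# "question_id" for scoring. A sketch of the final accuracy call (the temporary path is a
# placeholder; questions/annotations paths are the same ones passed into evaluate_vqa):
def _score_vqa_predictions_sketch(predictions, questions_json_path, annotations_json_path):
    tmp_path = "vqa_predictions_example.json"
    with open(tmp_path, "w") as f:
        json.dump(predictions, f)
    acc = compute_vqa_accuracy(tmp_path, questions_json_path, annotations_json_path, vqa_dataset="vqa")
    os.remove(tmp_path)
    return acc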
1055
+ def evaluate_refcoco(
1056
+ model,
1057
+ tokenizer,
1058
+ image_processor,
1059
+ batch_size,
1060
+ tsvfile,
1061
+ max_generation_length=20,
1062
+ num_beams=3,
1063
+ length_penalty=-2.0,
1064
+ device=-1,
1065
+ vis_embed_size=None,
1066
+ rank=0,
1067
+ world_size=1,
1068
+ id=0,
1069
+ ):
1070
+ model.eval().cuda()
1071
+ loc_token_ids = []
1072
+ for i in range(1000):
1073
+ loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
1074
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
1075
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
1076
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
1077
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
1078
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
1079
+ # all_ids = set(range(model.lang_encoder.lm_head.out_features))
1080
+ # bad_words_ids = list(all_ids - set(loc_token_ids))
1081
+ # bad_words_ids = [[b] for b in bad_words_ids]
1082
+ # min_loc_token_id = min(loc_token_ids)
1083
+ # max_loc_token_id = max(loc_token_ids)
1084
+ total = 0
1085
+ correct = 0
1086
+ ious = []
1087
+ if "refcocog" in tsvfile:
1088
+ dataset_name = "refcocog"
1089
+ elif "refcocoplus" in tsvfile:
1090
+ dataset_name = "refcocoplus"
1091
+ else:
1092
+ dataset_name = "refcoco"
1093
+ with open(tsvfile, "r") as f:
1094
+ lines = f.readlines()
1095
+ pbar = tqdm(lines, disable=(rank != 0))
1096
+ for ii, line in enumerate(pbar):
1097
+ if ii % world_size != rank:
1098
+ continue
1099
+ total += 1
1100
+ line = line.rstrip()
1101
+ uniq_id, image_id, text, region_coord, image = line.split("\t")
1102
+
1103
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
1104
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
1105
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
1106
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
1107
+
1108
+ gt_box = np.array(list(map(float, region_coord.split(","))))
1109
+ width = image.width
1110
+ height = image.height
1111
+ image = image.resize((224, 224))
1112
+ gt_box = gt_box / np.array([width, height, width, height]) * 224
1113
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1114
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
1115
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
1116
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
1117
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
1118
+
1119
+
1120
+ encodings = tokenizer(
1121
+ prompt,
1122
+ padding="longest",
1123
+ truncation=True,
1124
+ return_tensors="pt",
1125
+ max_length=2000,
1126
+ )
1127
+ input_ids = encodings["input_ids"]
1128
+ attention_mask = encodings["attention_mask"]
1129
+ # attention_mask[input_ids == prebox_token_id] = 0
1130
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1131
+ image_start_index_list = [[x] for x in image_start_index_list]
1132
+ image_nums = [1] * len(input_ids)
1133
+ vision_x = batch_images.cuda()
1134
+ lang_x = input_ids.cuda()
1135
+ attention_mask = attention_mask.cuda()
1136
+
1137
+ model.debug_id = 0
1138
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
1139
+ outputs = model(
1140
+ vision_x=vision_x,
1141
+ lang_x=lang_x,
1142
+ attention_mask=attention_mask,
1143
+ labels=None,
1144
+ image_nums=image_nums,
1145
+ image_start_index_list=image_start_index_list,
1146
+ added_bbox_list=None,
1147
+ add_box=False,
1148
+ )
1149
+ boxes = outputs["boxes"]
1150
+ scores = outputs["scores"]
1151
+ if len(scores) > 0:
1152
+ box = boxes[scores.argmax()]
1153
+ iou = get_iou(box, gt_box)
1154
+ else:
1155
+ iou = 0.0
1156
+ # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
1157
+ tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}")
1158
+ if iou >= 0.5:
1159
+ correct += 1
1160
+ pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
1161
+ # open_cv_image = np.array(image)
1162
+ # # Convert RGB to BGR
1163
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
1164
+ # for box, score in zip(boxes, scores):
1165
+ # open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
1166
+ # cv2.imwrite("output.jpg", open_cv_image)
1167
+ # print(boxes)
1168
+ # print(scores)
1169
+ # exit()
1170
+
1171
+
1172
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1173
+ f.write(json.dumps([total, correct]))
1174
+ if world_size > 1:
1175
+ torch.distributed.barrier()
1176
+ if rank == 0:
1177
+ total = 0
1178
+ correct = 0
1179
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1180
+ for rank_i in range(world_size):
1181
+ [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1182
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1183
+ total += total_part
1184
+ correct += correct_part
1185
+ score = correct / total
1186
+ print("score:", score)
1187
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
1188
+ pass
1189
+ else:
1190
+ score = 0.0
1191
+ if world_size > 1:
1192
+ torch.distributed.barrier()
1193
+ return score
1194
+
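# The RefCOCO scoring above counts a hit when IoU with the ground-truth box is at least 0.5,
# after both boxes are mapped into the 224x224 resized image. A small sketch of that check
# (pred_box is assumed to already be in 224x224 coordinates):
def _refcoco_hit_sketch(pred_box, region_coord, width, height):
    # region_coord: "x1,y1,x2,y2" in original image pixels, as read from the TSV
    gt_box = np.array(list(map(float, region_coord.split(",")))) / np.array([width, height, width, height]) * 224
    return get_iou(pred_box, gt_box) >= 0.5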
1195
+
1196
+
1197
+ # def preprocess_visual_info(Text):
1198
+ # text = Text.split(" ")
1199
+ # for is_idx, t in enumerate(text):
1200
+ # if t == "is":
1201
+ # break
1202
+ # the_idx = is_idx
1203
+ # while text[the_idx] != "the":
1204
+ # the_idx -= 1
1205
+ # obj_A = " ".join(text[the_idx+1:is_idx])
1206
+ # second_the_idx = len(text) - 1
1207
+ # while text[second_the_idx] != "the":
1208
+ # second_the_idx -= 1
1209
+ # obj_B = " ".join(text[second_the_idx+1:])
1210
+ # visual_obj_A = f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
1211
+ # visual_obj_B = f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
1212
+ # Text = Text.replace(obj_A, f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
1213
+ # Text = Text.replace(obj_B, f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
1214
+ # return Text, obj_A, obj_B, visual_obj_A, visual_obj_B
1215
+
1216
+
1217
+ def preprocess_visual_info(Text):
1218
+ text = Text.split(" ")
1219
+ for is_idx, t in enumerate(text):
1220
+ if t == "is":
1221
+ break
1222
+ the_idx = is_idx
1223
+ while text[the_idx] != "the":
1224
+ the_idx -= 1
1225
+ obj_A = " ".join(text[the_idx+1:is_idx])
1226
+ second_the_idx = len(text) - 1
1227
+ while text[second_the_idx] != "the":
1228
+ second_the_idx -= 1
1229
+ obj_B = " ".join(text[second_the_idx+1:])
1230
+ relation = " ".join(text[is_idx+1:second_the_idx])
1231
+ visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
1232
+ visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>"
1233
+ Text = f"{visual_obj_A} is {relation} {visual_obj_B}"
1234
+ return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation
1235
+
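# Worked example of what preprocess_visual_info produces, assuming the caption follows
# the ARO "the A is R the B" pattern handled above:
def _preprocess_visual_info_example():
    text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(
        "the man is sitting on the fire hydrant"
    )
    assert obj_A == "man" and relation == "sitting on" and obj_B == "fire hydrant"
    # visual_obj_A tags the grounded subject; visual_obj_B leaves a <|#prebox#|> slot to be
    # filled with a candidate box for the object, so text becomes:
    # "<|#object#|>the man<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> is sitting on
    #  <|#object#|><|#previsual#|><|#prebox#|><|#object#|>the fire hydrant<|#endofobject#|>"
    return text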
1236
+
1237
+
1238
+
1239
+ def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
1240
+ assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
1241
+ encodings = tokenizer(
1242
+ prompt,
1243
+ padding="longest",
1244
+ truncation=True,
1245
+ return_tensors="pt",
1246
+ max_length=2000,
1247
+ )
1248
+ input_ids = encodings["input_ids"]
1249
+ attention_mask = encodings["attention_mask"]
1250
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1251
+ image_start_index_list = [[x] for x in image_start_index_list]
1252
+ image_nums = [1] * len(input_ids)
1253
+ vision_x = batch_images.cuda()
1254
+ lang_x = input_ids.cuda()
1255
+ attention_mask = attention_mask.cuda()
1256
+
1257
+ model.debug_id = 0
1258
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
1259
+ outputs = model(
1260
+ vision_x=vision_x,
1261
+ lang_x=lang_x,
1262
+ attention_mask=attention_mask,
1263
+ labels=None,
1264
+ image_nums=image_nums,
1265
+ image_start_index_list=image_start_index_list,
1266
+ added_bbox_list=visual_box_list,
1267
+ add_box=visual_box_list is not None,
1268
+ relations=None,
1269
+ debug_mode=False,
1270
+ )
1271
+ boxes = outputs["boxes"]
1272
+ scores = outputs["scores"]
1273
+ if debug:
1274
+ import pdb; pdb.set_trace()
1275
+ if return_all:
1276
+ return boxes, scores
1277
+ if len(scores) == 0:
1278
+ return None, None
1279
+ else:
1280
+ return boxes[scores.argmax()], scores.max()
1281
+
1282
+
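# Illustrative call pattern for get_bbox, mirroring how evaluate_aro uses it below:
# first ground the subject with no box prompt, then query candidate boxes for the object
# conditioned on the subject's box. The prompt strings are examples; the /224 divisor
# matches the resized image side used throughout this file.
def _get_bbox_example(model, tokenizer, batch_images, vis_embed_size, media_token_id, prebox_token_id):
    prefix = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>"
    # 1) best box for the subject
    prompt = [prefix + "<|#object#|>the man<|#endofobject#|><|#visual#|>"]
    first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer,
                                      media_token_id, prebox_token_id)
    if first_box is None:
        return None
    # 2) all candidate boxes for the object, conditioned on the subject's box
    prompt = [prefix + "<|#object#|>the man<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
              " is sitting on<|#object#|><|#previsual#|>"]
    boxes, scores = get_bbox([torch.tensor(first_box).unsqueeze(0).cuda() / 224],
                             batch_images, prompt, model, tokenizer,
                             media_token_id, prebox_token_id, return_all=True)
    return boxes, scores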
1283
+ def evaluate_aro(
1284
+ model,
1285
+ tokenizer,
1286
+ image_processor,
1287
+ batch_size,
1288
+ tsvfile,
1289
+ max_generation_length=20,
1290
+ num_beams=3,
1291
+ length_penalty=-2.0,
1292
+ device=-1,
1293
+ vis_embed_size=None,
1294
+ rank=0,
1295
+ world_size=1,
1296
+ id=0,
1297
+ add_visual=True,
1298
+ add_relation=False,
1299
+ subset=False,
1300
+ choose_left_right=True,
1301
+ ):
1302
+ both_failed_ids = json.load(open("both_failed_ids.json"))
1303
+ os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
1304
+ # from groundingdino.demo.caption_grounder import caption_grounder
1305
+ # generator = caption_grounder(
1306
+ # config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
1307
+ # checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
1308
+ # cpu_only=False,
1309
+ # box_threshold=0.1, text_threshold=0.1,
1310
+ # )
1311
+ dataset_name = "aro"
1312
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
1313
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
1314
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
1315
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
1316
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
1317
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
1318
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
1319
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
1320
+ model.eval().cuda()
1321
+ total = 0
1322
+ correct = 0
1323
+ from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
1324
+ vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
1325
+ with open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/unilm/kosmos-2/labels.json") as f:
1326
+ all_labels = json.load(f)
1327
+ label_ids = tokenizer(all_labels).input_ids
1328
+ label_ids = sorted(list(set([x[0] for x in label_ids])))
1329
+
1330
+ if subset:
1331
+ subset_idx = json.load(open("aro_subset.json"))
1332
+ pbar = tqdm(subset_idx, disable=(rank != 0))
1333
+ else:
1334
+ pbar = tqdm(vgr_dataset, disable=(rank != 0))
1335
+ for ii, sample in enumerate(pbar):
1336
+ if subset:
1337
+ ORI_IDX = int(sample)
1338
+ sample = vgr_dataset[sample]
1339
+ # if ORI_IDX != 19036:
1340
+ # continue
1341
+ if ii % world_size != rank:
1342
+ continue
1343
+
1344
+ # not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0])
1345
+ # if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right):
1346
+ # if rank == 0:
1347
+ # tqdm.write(f"SKIP: {sample['caption_options'][1]}")
1348
+ # continue
1349
+ total += 1
1350
+ # image = sample["image_options"][0]
1351
+ image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/man_on_hydrant.png").convert("RGB")
1352
+ image = image.resize((224, 224))
1353
+
1354
+ # text = sample["caption_options"][1] # 1 is true caption
1355
+ text = "the man is sitting on the fire hydrant"
1356
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1357
+ text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
1358
+
1359
+
1360
+ first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
1361
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
1362
+ first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
1363
+
1364
+
1365
+ # use grounding DINO to get the first bbox
1366
+ # caption = f"{obj_A}"
1367
+ # with torch.no_grad():
1368
+ # logits, boxes = generator.ground_caption_raw(image_pil=image, caption=caption)
1369
+ # boxes_filt, pred_phrases = generator.postprocess(logits, boxes, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
1370
+ # objects = {}
1371
+ # for box, phrase in zip(boxes_filt, pred_phrases):
1372
+ # obj, score = phrase
1373
+ # obj = obj[0]
1374
+ # if obj not in objects:
1375
+ # objects[obj] = (score, box)
1376
+ # if objects[obj][0] < score:
1377
+ # objects[obj] = (score, box)
1378
+ # try:
1379
+ # first_box = objects[obj_A][1].clone()
1380
+ # first_box[:2] -= first_box[2:] / 2
1381
+ # first_box[2:] += first_box[:2]
1382
+ # first_box = first_box.clamp(0, 0.99) * 224.0
1383
+ # first_box = first_box.numpy()
1384
+ # first_score = objects[obj_A][0]
1385
+ # except:
1386
+ # first_box = None
1387
+
1388
+ if first_box is None:
1389
+ text_A = "the " + obj_A
1390
+ added_bbox_list = None
1391
+ else:
1392
+ text_A = visual_obj_A
1393
+ added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
1394
+
1395
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
1396
+ pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
1397
+ prebox_token_id, return_all=True)
1398
+
1399
+
1400
+ # open_cv_image = np.array(image)
1401
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
1402
+ # for box, score in zip(pre_box, pre_score):
1403
+ # print(box, score)
1404
+ # if score > 0.1:
1405
+ # open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (0, 255, 0), 2)
1406
+ # cv2.imwrite(f"test1.jpg", open_cv_image)
1407
+ # print(sample["caption_options"][idx])
1408
+ # exit()
1409
+
1410
+
1411
+
1412
+ if pre_boxes is None:
1413
+ pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
1414
+ pre_scores = [1.0]
1415
+
1416
+ rank_list = []
1417
+ # pre_boxes = [pre_boxes[0]]
1418
+ # pre_scores = [pre_scores[0]]
1419
+ for pre_box, pre_score in zip(pre_boxes, pre_scores):
1420
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
1421
+
1422
+ encodings = tokenizer(
1423
+ prompt,
1424
+ padding="longest",
1425
+ truncation=True,
1426
+ return_tensors="pt",
1427
+ max_length=512,
1428
+ )
1429
+ input_ids = encodings["input_ids"]
1430
+ attention_mask = encodings["attention_mask"]
1431
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1432
+ image_start_index_list = [[x] for x in image_start_index_list]
1433
+ image_nums = [1] * len(input_ids)
1434
+ vision_x = batch_images.cuda()
1435
+ lang_x = input_ids.cuda()
1436
+ attention_mask = attention_mask.cuda()
1437
+ labels = lang_x.clone()
1438
+
1439
+ answer_start_idx = (labels == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1] + 1
1440
+ # pre_box = None
1441
+ labels[0, :answer_start_idx] = -100
1442
+ # # labels[labels == endofobject_token_id] = -100
1443
+ # labels[:, 0] = -100
1444
+ # labels[labels == visual_token_id] = -100
1445
+ # labels[labels == box_token_id] = -100
1446
+ # labels[labels == previsual_token_id] = -100
1447
+ # labels[labels == prebox_token_id] = -100
1448
+ # labels[labels == endofattr_token_id] = -100
1449
+ # labels[labels == tokenizer.pad_token_id] = -100
1450
+ # labels[labels == media_token_id] = -100
1451
+ # labels[labels == endofmedia_token_id] = -100
1452
+ answer_ids = tokenizer(f" {obj_B}", add_special_tokens=False)["input_ids"]
1453
+ labels[input_ids == visual_token_id] = -100
1454
+ labels[input_ids == box_token_id] = -100
1455
+ labels[input_ids == endofattr_token_id] = -100
1456
+ labels[input_ids == previsual_token_id] = -100
1457
+ labels[input_ids == prebox_token_id] = -100
1458
+ labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
1459
+ labels[torch.roll(input_ids == box_token_id, 1)] = -100
1460
+ labels[:, 0] = -100
1461
+ labels[input_ids == tokenizer.pad_token_id] = -100
1462
+ labels[input_ids == media_token_id] = -100
1463
+ labels[input_ids == endofmedia_token_id] = -100
1464
+
1465
+ added_bbox_list = None
1466
+ if add_visual:
1467
+ added_bbox_list = []
1468
+ if first_box is not None:
1469
+ added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
1470
+ if pre_box is not None:
1471
+ added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
1472
+ if added_bbox_list is not None and len(added_bbox_list) == 0:
1473
+ added_bbox_list = None
1474
+
1475
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():  # chaining contexts with "and" would enter only one of them
1476
+ outputs = model(
1477
+ vision_x=vision_x,
1478
+ lang_x=lang_x,
1479
+ attention_mask=attention_mask,
1480
+ labels=labels,
1481
+ image_nums=image_nums,
1482
+ image_start_index_list=image_start_index_list,
1483
+ added_bbox_list=added_bbox_list,
1484
+ add_box=added_bbox_list is not None,
1485
+ relations=None,
1486
+ )
1487
+ logits = outputs["logits"][0, answer_start_idx:]
1488
+ # _rank = logits[0][label_ids].sort(descending=True).indices.tolist().index(label_ids.index(answer_ids[0]))
1489
+ _rank = logits[0].sort(descending=True).indices.tolist().index(answer_ids[0])
1490
+ print(tokenizer.decode(logits[0].sort(descending=True).indices.tolist()[:10]))
1491
+ print(tokenizer.decode(logits[1].sort(descending=True).indices.tolist()[:10]))
1492
+ rank_list.append(_rank)
1493
+ # open_cv_image = np.array(image)
1494
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
1495
+ # if first_box is not None:
1496
+ # open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
1497
+ # if pre_box is not None:
1498
+ # open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
1499
+
1500
+ # font = cv2.FONT_HERSHEY_SIMPLEX
1501
+ # org = [10, 20]
1502
+ # fontScale = 0.5
1503
+ # color = (0, 0, 0)
1504
+ # thickness = 1
1505
+ # open_cv_image = cv2.resize(open_cv_image, (512, 512))
1506
+ # put_text = sample["caption_options"][1]
1507
+ # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
1508
+ # org[1] += 20
1509
+ # put_text = "top10 in green box"
1510
+ # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
1511
+ # fontScale = 1.0
1512
+ # thickness = 2
1513
+ # for ind in logits_list[i][0].sort(descending=True).indices[:10]:
1514
+ # org[1] += 20
1515
+ # put_text = f"{tokenizer.decode(ind)}"
1516
+ # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
1517
+ # tqdm.write(f"{tokenizer.decode(logits_list[i][0].sort(descending=True).indices[:10])}")
1518
+ # tqdm.write(f"{rank_list}")
1519
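+ # Scoring rule: a caption counts as correct if, for at least one candidate box, the ground-truth object token appears in the top-10 next-token predictions (the lowest rank across boxes decides).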
+ final_rank = min(rank_list)
1520
+ if final_rank < 10:
1521
+ correct += 1
1522
+ TYPE = "CORRECT"
1523
+ # if ii in both_failed_ids:
1524
+ # tqdm.write(f"case find->{sample['caption_options'][1]}")
1525
+ # image.save(f"case_study/{ii}_{rank_list}_{sample['caption_options'][1]}.jpg")
1526
+ if rank == 0:
1527
+ tqdm.write(f"correct: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
1528
+ else:
1529
+ TYPE = "WRONG"
1530
+ if rank == 0:
1531
+ tqdm.write(f"wrong: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
1532
+ # cv2.imwrite(f"visualization/aro_results_{id}/{TYPE}_{ORI_IDX}.jpg", open_cv_image)
1533
+ pbar.set_description(f"score: {correct / total:.4f} | {final_rank}")
1534
+
1535
+
1536
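+ # Each rank dumps its partial [total, correct] counts to a JSON file; after the barrier, rank 0 merges them, prints the score, and writes an empty marker file whose name encodes the score.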
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1537
+ f.write(json.dumps([total, correct]))
1538
+ if world_size > 1:
1539
+ torch.distributed.barrier()
1540
+ if rank == 0:
1541
+ total = 0
1542
+ correct = 0
1543
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1544
+ for rank_i in range(world_size):
1545
+ [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1546
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1547
+ total += total_part
1548
+ correct += correct_part
1549
+ score = correct / total
1550
+ print("score:", score, "total:", total)
1551
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
1552
+ pass
1553
+ else:
1554
+ score = 0.0
1555
+ if world_size > 1:
1556
+ torch.distributed.barrier()
1557
+ return score
1558
+
1559
+
1560
+ def evaluate_pisc(
1561
+ model,
1562
+ tokenizer,
1563
+ image_processor,
1564
+ batch_size,
1565
+ tsvfile,
1566
+ max_generation_length=20,
1567
+ num_beams=3,
1568
+ length_penalty=-2.0,
1569
+ device=-1,
1570
+ vis_embed_size=None,
1571
+ rank=0,
1572
+ world_size=1,
1573
+ id=0,
1574
+ add_visual=True,
1575
+ ):
1576
+ from open_flamingo.train.instruction_template import PISC_TEMPLATES
1577
+ dataset_name = "pisc"
1578
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
1579
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
1580
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
1581
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
1582
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
1583
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
1584
+ model.eval().cuda()
1585
+
1586
+ dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd")
1587
+ pbar = tqdm(dataset, disable=(rank != 0))
1588
+
1589
+ rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
1590
+ rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)}
1591
+ gt = []
1592
+ pred_scores = []
1593
+ for III, sample in enumerate(pbar):
1594
+ if III % world_size != rank:
1595
+ continue
1596
+ image_path, dataset, data = sample
1597
+ image = Image.open(image_path)
1598
+ size = image_processor.transforms[0].size
1599
+ image = image.resize((size, size))
1600
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1601
+ boxA = data[0]
1602
+ boxB = data[1]
1603
+ gt_relation = data[2]
1604
+ losses = []
1605
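+ # Score each candidate relation by the language-modeling loss of its filled-in PISC template, conditioned on the two person boxes; the option with the lowest loss becomes the prediction.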
+ for i_rel, option_rel in enumerate(rel_id_to_type):
1606
+ text = PISC_TEMPLATES[0].format(relation=option_rel)
1607
+ added_bbox = [
1608
+ torch.tensor([boxA]).cuda(),
1609
+ torch.tensor([boxB]).cuda(),
1610
+ ]
1611
+ caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}"
1612
+ encodings = tokenizer(
1613
+ caption,
1614
+ padding="longest",
1615
+ truncation=True,
1616
+ return_tensors="pt",
1617
+ max_length=2000,
1618
+ )
1619
+ input_ids = encodings["input_ids"]
1620
+ attention_mask = encodings["attention_mask"]
1621
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1622
+ image_start_index_list = [[x] for x in image_start_index_list]
1623
+ image_nums = [1] * len(input_ids)
1624
+ vision_x = batch_images.cuda()
1625
+ lang_x = input_ids.cuda()
1626
+ attention_mask = attention_mask.cuda()
1627
+
1628
+ labels = lang_x.clone()
1629
+ labels[labels == tokenizer.pad_token_id] = -100
1630
+ if add_visual:
1631
+ # endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True))
1632
+ # endofattr_next_token_index[1] += 1
1633
+ # endofattr_next_token_id = labels[endofattr_next_token_index]
1634
+ # </obj><visual><box></attr>NEXT_WORD
1635
+ # </obj> predict NEXT_WORD
1636
+ # <visual><box></attr> predict nothing
1637
+ labels[labels == visual_token_id] = -100
1638
+ labels[labels == box_token_id] = -100
1639
+ labels[labels == endofattr_token_id] = -100
1640
+ # labels[endofattr_next_token_index] = -100
1641
+ labels[:, 0] = -100
1642
+ answer_token_id = tokenizer(" Answer").input_ids[0]
1643
+ answer_token_loc = (input_ids == answer_token_id).nonzero()
1644
+ for batch_idx, idx in answer_token_loc:
1645
+ labels[batch_idx][:idx+2] = -100
1646
+
1647
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
1648
+ outputs = model(
1649
+ vision_x=vision_x,
1650
+ lang_x=lang_x,
1651
+ attention_mask=attention_mask,
1652
+ labels=labels,
1653
+ image_nums=image_nums,
1654
+ image_start_index_list=image_start_index_list,
1655
+ added_bbox_list=added_bbox,
1656
+ add_box=added_bbox is not None,
1657
+ )
1658
+ loss_total = outputs.loss.reshape(labels.shape[0], -1)
1659
+ loss = loss_total.sum() / (loss_total != 0).sum()
1660
+ losses.append(loss.item())
1661
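+ # Convert the six per-option losses into a normalized score vector via a softmax over the negative losses.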
+ pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum())
1662
+ gt.append(rel_type_to_id[gt_relation])
1663
+ gt = np.array(gt)
1664
+ pred_scores = np.array(pred_scores)
1665
+ pred = pred_scores.argmax(1)
1666
+
1667
+
1668
+ print("total num:", len(gt))
1669
+ recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
1670
+ print("recalls:", recalls)
1671
+
1672
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1673
+ f.write(json.dumps([gt.tolist(), pred.tolist()]))
1674
+ if world_size > 1:
1675
+ torch.distributed.barrier()
1676
+ if rank == 0:
1677
+ gt = []
1678
+ pred = []
1679
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1680
+ for rank_i in range(world_size):
1681
+ [gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1682
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1683
+ gt.extend(gt_part)
1684
+ pred.extend(pred_part)
1685
+ print("total num:", len(gt))
1686
+ recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
1687
+ print("recalls:", recalls)
1688
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f:
1689
+ f.write(f"{gt}\n")
1690
+ f.write(f"{pred}\n")
1691
+ f.write(f"{recalls}\n")
1692
+ score = 0.0
1693
+ if world_size > 1:
1694
+ torch.distributed.barrier()
1695
+ return score
1696
+
1697
+
1698
+
1699
+ if __name__ == "__main__":
1700
+ main()
multimodal/build/lib/open_flamingo/eval/evaluate_temp.py ADDED
@@ -0,0 +1,1838 @@
1
+ import argparse
2
+ import json
+ import logging  # used by OKVQAPostProcess.lemmatizer's error path
3
+ from math import ceil
4
+ import os
5
+ import random
6
+ import uuid
7
+ from collections import defaultdict
8
+ from typing import Callable
9
+ import time
10
+ import cv2
11
+ import webdataset as wds
12
+ from sklearn.metrics import recall_score, average_precision_score
13
+
14
+ import more_itertools
15
+ import numpy as np
16
+ import torch
17
+ from coco_metric import compute_cider, postprocess_captioning_generation
18
+ from eval_datasets import VQADataset, GQADataset
19
+ from tqdm import tqdm
20
+ from collections import Counter
21
+
22
+ from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
23
+ from open_flamingo.eval.classification import (
24
+ compute_per_sample_probs,
25
+ compute_per_sample_loss,
26
+ )
27
+ from open_flamingo.eval.imagenet_utils import (
28
+ openai_imagenet_classnames,
29
+ IMAGENET_1K_CLASS_ID_TO_LABEL,
30
+ )
31
+
32
+ from open_flamingo.src.factory import create_model_and_transforms
33
+ from PIL import Image
34
+ from io import BytesIO
35
+ import base64
36
+ from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
37
+ import string
38
+ from lavis.datasets.builders import load_dataset
39
+
40
+
41
+ def get_iou(box1, box2):
42
+ # box1 and box2 should be in the format [x1, y1, x2, y2]
43
+ intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
44
+ max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
45
+ area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
46
+ area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
47
+ union = area_box1 + area_box2 - intersection
48
+ iou = intersection / union if union > 0 else 0
49
+ return iou
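+ # Worked example: get_iou([0, 0, 10, 10], [5, 5, 15, 15]) -> intersection 25, union 175, IoU ~= 0.143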
50
+
51
+ def expand2square(pil_img, background_color):
52
+ width, height = pil_img.size
53
+ if width == height:
54
+ return pil_img
55
+ elif width > height:
56
+ result = Image.new(pil_img.mode, (width, width), background_color)
57
+ result.paste(pil_img, (0, (width - height) // 2))
58
+ return result
59
+ else:
60
+ result = Image.new(pil_img.mode, (height, height), background_color)
61
+ result.paste(pil_img, ((height - width) // 2, 0))
62
+ return result
63
+
64
+ parser = argparse.ArgumentParser()
65
+ parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
66
+ parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
67
+ parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
68
+ parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
69
+ parser.add_argument("--checkpoint_path", type=str, required=True)
70
+ parser.add_argument(
71
+ "--results_file", type=str, default=None, help="JSON file to save results"
72
+ )
73
+
74
+ # Trial arguments
75
+ parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
76
+ parser.add_argument(
77
+ "--num_trials",
78
+ type=int,
79
+ default=1,
80
+ help="Number of trials to run for each shot using different demonstrations",
81
+ )
82
+ parser.add_argument(
83
+ "--trial_seeds",
84
+ nargs="+",
85
+ default=[0],
86
+ help="Seeds to use for each trial for picking demonstrations and eval sets",
87
+ )
88
+ parser.add_argument(
89
+ "--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
90
+ )
91
+
92
+ parser.add_argument("--batch_size", type=int, default=8)
93
+
94
+ # Per-dataset evaluation flags
95
+ parser.add_argument(
96
+ "--eval_coco",
97
+ action="store_true",
98
+ default=False,
99
+ help="Whether to evaluate on COCO.",
100
+ )
101
+ parser.add_argument(
102
+ "--eval_vqav2",
103
+ action="store_true",
104
+ default=False,
105
+ help="Whether to evaluate on VQAV2.",
106
+ )
107
+ parser.add_argument(
108
+ "--eval_ok_vqa",
109
+ action="store_true",
110
+ default=False,
111
+ help="Whether to evaluate on OK-VQA.",
112
+ )
113
+ parser.add_argument(
114
+ "--eval_imagenet",
115
+ action="store_true",
116
+ default=False,
117
+ help="Whether to evaluate on ImageNet.",
118
+ )
119
+
120
+ parser.add_argument(
121
+ "--eval_flickr30",
122
+ action="store_true",
123
+ default=False,
124
+ help="Whether to evaluate on Flickr30.",
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--eval_refcoco",
129
+ action="store_true",
130
+ default=False,
131
+ help="Whether to evaluate on RefCOCO.",
132
+ )
133
+
134
+ # Dataset arguments
135
+
136
+ ## Flickr30 Dataset
137
+ parser.add_argument(
138
+ "--flickr_image_dir_path",
139
+ type=str,
140
+ help="Path to the flickr30/flickr30k_images directory.",
141
+ default=None,
142
+ )
143
+ parser.add_argument(
144
+ "--flickr_annotations_json_path",
145
+ type=str,
146
+ help="Path to the dataset_flickr30k_coco_style.json file.",
147
+ default=None,
148
+ )
149
+
150
+ ## COCO Dataset
151
+ parser.add_argument(
152
+ "--coco_image_dir_path",
153
+ type=str,
154
+ help="Path to the flickr30/flickr30k_images directory.",
155
+ default=None,
156
+ )
157
+ parser.add_argument(
158
+ "--coco_annotations_json_path",
159
+ type=str,
160
+ default=None,
161
+ )
162
+
163
+ ## VQAV2 Dataset
164
+ parser.add_argument(
165
+ "--vqav2_image_dir_path",
166
+ type=str,
167
+ default=None,
168
+ )
169
+ parser.add_argument(
170
+ "--vqav2_questions_json_path",
171
+ type=str,
172
+ default=None,
173
+ )
174
+ parser.add_argument(
175
+ "--vqav2_annotations_json_path",
176
+ type=str,
177
+ default=None,
178
+ )
179
+
180
+ ## OK-VQA Dataset
181
+ parser.add_argument(
182
+ "--ok_vqa_image_dir_path",
183
+ type=str,
184
+ help="Path to the vqav2/train2014 directory.",
185
+ default=None,
186
+ )
187
+ parser.add_argument(
188
+ "--ok_vqa_questions_json_path",
189
+ type=str,
190
+ help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
191
+ default=None,
192
+ )
193
+ parser.add_argument(
194
+ "--ok_vqa_annotations_json_path",
195
+ type=str,
196
+ help="Path to the v2_mscoco_train2014_annotations.json file.",
197
+ default=None,
198
+ )
199
+
200
+ ## Imagenet dataset
201
+ parser.add_argument("--imagenet_root", type=str, default="/tmp")
202
+
203
+ ## RefCOCO dataset
204
+ parser.add_argument("--refcoco_tsvfile", type=str, default=None)
205
+
206
+ parser.add_argument(
207
+ "--location_token_num",
208
+ default=1000,
209
+ type=int,
210
+ )
211
+ # distributed training
212
+ parser.add_argument(
213
+ "--dist-url",
214
+ default="env://",
215
+ type=str,
216
+ help="url used to set up distributed training",
217
+ )
218
+ parser.add_argument(
219
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
220
+ )
221
+ parser.add_argument(
222
+ "--horovod",
223
+ default=False,
224
+ action="store_true",
225
+ help="Use horovod for distributed training.",
226
+ )
227
+ parser.add_argument(
228
+ "--no-set-device-rank",
229
+ default=False,
230
+ action="store_true",
231
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
232
+ )
233
+ parser.add_argument(
234
+ "--dist",
235
+ default=False,
236
+ action="store_true",
237
+ )
238
+ parser.add_argument(
239
+ "--lora",
240
+ default=False,
241
+ action="store_true",
242
+ )
243
+ parser.add_argument(
244
+ "--lora_r",
245
+ default=16,
246
+ type=int,
247
+ required=False,
248
+ )
249
+ parser.add_argument(
250
+ "--legacy",
251
+ default=False,
252
+ action="store_true",
253
+ )
254
+ parser.add_argument(
255
+ "--special",
256
+ default=False,
257
+ action="store_true",
258
+ )
259
+ parser.add_argument(
260
+ "--id",
261
+ default=0,
262
+ type=int,
263
+ required=False,
264
+ )
265
+
266
+ parser.add_argument(
267
+ "--eval_gqa",
268
+ default=False,
269
+ action="store_true",
270
+ )
271
+ parser.add_argument(
272
+ "--use_sam",
273
+ default=None,
274
+ type=str,
275
+ required=False,
276
+ )
277
+ parser.add_argument(
278
+ "--add_visual_token",
279
+ default=False,
280
+ action="store_true",
281
+ )
282
+ parser.add_argument(
283
+ "--use_format_v2",
284
+ default=False,
285
+ action="store_true",
286
+ )
287
+ parser.add_argument(
288
+ "--eval_aro",
289
+ default=False,
290
+ action="store_true",
291
+ )
292
+ parser.add_argument(
293
+ "--eval_pisc",
294
+ default=False,
295
+ action="store_true",
296
+ )
297
+
298
+
299
+ class OKVQAPostProcess():
300
+ def __init__(self):
301
+ self._lemmatizer = None
302
+
303
+ def _lemmatize(self, answers):
304
+ def apply(answer):
305
+ doc = self.lemmatizer(answer)
306
+
307
+ words = []
308
+ for token in doc:
309
+ if token.pos_ in ["NOUN", "VERB"]:
310
+ words.append(token.lemma_)
311
+ else:
312
+ words.append(token.text)
313
+ answer = " ".join(words)
314
+
315
+ return answer
316
+
317
+ return [apply(answer) for answer in answers]
318
+
319
+ @property
320
+ def lemmatizer(self):
321
+ if self._lemmatizer is None:
322
+ try:
323
+ import spacy
324
+
325
+ self._lemmatizer = spacy.load("en_core_web_sm")
326
+ except ImportError:
327
+ logging.error(
328
+ """
329
+ Please install spacy and en_core_web_sm model to apply lemmatization.
330
+ python -m spacy download en_core_web_sm
331
+ OR
332
+ import spacy.cli
333
+ spacy.cli.download("en_core_web_sm")
334
+ """
335
+ )
336
+ exit(1)
337
+
338
+ return self._lemmatizer
339
+
340
+
341
+ def main():
342
+ args = parser.parse_args()
343
+ if args.dist:
344
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
345
+ print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
346
+ device_id = init_distributed_device(args)
347
+ else:
348
+ args.rank = 0
349
+ args.world_size = 1
350
+ print(f"rank: {args.rank} world_size: {args.world_size}")
351
+
352
+ if "sam" in args.checkpoint_path:
353
+ args.use_sam = "vit_l"
354
+
355
+ args.add_visual_token = True
356
+ if "lora" in args.checkpoint_path:
357
+ args.lora = True
358
+
359
+
360
+ args.add_pe = False
361
+ args.add_box = True
362
+ args.relation = False
363
+ args.enhance_data = False
364
+ args.use_format_v2 = True
365
+
366
+
367
+
368
+ import hashlib
369
+ args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest()
370
+
371
+ # load model
372
+ flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
373
+ args.vision_encoder_path,
374
+ args.vision_encoder_pretrained,
375
+ args.lm_path,
376
+ args.lm_tokenizer_path,
377
+ location_token_num=args.location_token_num,
378
+ lora=args.lora,
379
+ lora_r=16,
380
+ use_sam=args.use_sam,
381
+ add_visual_token=args.add_visual_token,
382
+ use_format_v2=args.use_format_v2,
383
+ add_box=args.add_box,
384
+ add_pe=args.add_pe,
385
+ add_relation=args.relation,
386
+ enhance_data=args.enhance_data,
387
+ )
388
+ flamingo.use_format_v2 = args.use_format_v2
389
+ if args.special:
390
+ flamingo.special = True
391
+ else:
392
+ flamingo.special = False
393
+ if args.legacy:
394
+ flamingo.legacy = True
395
+ print("use legacy evaluation")
396
+ flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
397
+ flamingo.expr_name = args.checkpoint_path.split("/")[-2]
398
+ if args.rank == 0:
399
+ print("legacy", True if hasattr(flamingo, "legacy") else False)
400
+ print("step:", flamingo.step_num)
401
+ print("expr:", flamingo.expr_name)
402
+ print("use format v2:", flamingo.use_format_v2)
403
+ print(args)
404
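+ # Load the checkpoint and strip the DistributedDataParallel "module." prefix so the keys match the bare (unwrapped) model.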
+ checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
405
+ model_state_dict = {}
406
+ for key in checkpoint["model_state_dict"].keys():
407
+ model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
408
+ if "vision_encoder.logit_scale"in model_state_dict:
409
+ # previous checkpoint has some unnecessary weights
410
+ del model_state_dict["vision_encoder.logit_scale"]
411
+ del model_state_dict["vision_encoder.visual.proj"]
412
+ del model_state_dict["vision_encoder.visual.ln_post.weight"]
413
+ del model_state_dict["vision_encoder.visual.ln_post.bias"]
414
+ flamingo.load_state_dict(model_state_dict, strict=True)
415
+ results = defaultdict(list)
416
+ if args.eval_coco:
417
+ print("Evaluating on COCO...")
418
+ for shot in args.shots:
419
+ scores = []
420
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
421
+ cider_score = evaluate_coco_flickr(
422
+ model=flamingo,
423
+ tokenizer=tokenizer,
424
+ image_processor=image_processor,
425
+ batch_size=args.batch_size,
426
+ image_dir_path=args.coco_image_dir_path,
427
+ annotations_json_path=args.coco_annotations_json_path,
428
+ device=args.device,
429
+ seed=seed,
430
+ vis_embed_size=vis_embed_size,
431
+ rank=args.rank,
432
+ world_size=args.world_size,
433
+ id=args.id,
434
+ )
435
+ print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
436
+ scores.append(cider_score)
437
+ print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
438
+ results["coco"].append(
439
+ {"shots": shot, "trials": scores, "mean": np.mean(scores)}
440
+ )
441
+
442
+ if args.eval_ok_vqa:
443
+ print("Evaluating on OK-VQA...")
444
+ for shot in args.shots:
445
+ scores = []
446
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
447
+ ok_vqa_score = evaluate_vqa(
448
+ model=flamingo,
449
+ tokenizer=tokenizer,
450
+ image_processor=image_processor,
451
+ batch_size=args.batch_size,
452
+ image_dir_path=args.ok_vqa_image_dir_path,
453
+ questions_json_path=args.ok_vqa_questions_json_path,
454
+ annotations_json_path=args.ok_vqa_annotations_json_path,
455
+ vqa_dataset="ok_vqa",
456
+ vis_embed_size=vis_embed_size,
457
+ rank=args.rank,
458
+ world_size=args.world_size,
459
+ id=args.id,
460
+ )
461
+ results["ok_vqa"].append(
462
+ {"shots": shot, "score": ok_vqa_score}
463
+ )
464
+
465
+ if args.eval_vqav2:
466
+ print("Evaluating on VQAv2...")
467
+ for shot in args.shots:
468
+ scores = []
469
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
470
+ vqa_score = evaluate_vqa(
471
+ model=flamingo,
472
+ tokenizer=tokenizer,
473
+ image_processor=image_processor,
474
+ batch_size=args.batch_size,
475
+ image_dir_path=args.vqav2_image_dir_path,
476
+ questions_json_path=args.vqav2_questions_json_path,
477
+ annotations_json_path=args.vqav2_annotations_json_path,
478
+ vqa_dataset="vqa",
479
+ vis_embed_size=vis_embed_size,
480
+ rank=args.rank,
481
+ world_size=args.world_size,
482
+ id=args.id,
483
+ )
484
+ results["vqav2"].append(
485
+ {"shots": shot, "score": vqa_score}
486
+ )
487
+
488
+ if args.eval_gqa:
489
+ print("Evaluating on GQA...")
490
+ for shot in args.shots:
491
+ scores = []
492
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
493
+ vqa_score = evaluate_vqa(
494
+ model=flamingo,
495
+ tokenizer=tokenizer,
496
+ image_processor=image_processor,
497
+ batch_size=args.batch_size,
498
+ vqa_dataset="gqa",
499
+ vis_embed_size=vis_embed_size,
500
+ rank=args.rank,
501
+ world_size=args.world_size,
502
+ id=args.id,
503
+ )
504
+ results["gqa"].append(
505
+ {"shots": shot, "score": vqa_score}
506
+ )
507
+
508
+ if args.eval_imagenet:
509
+ print("Evaluating on ImageNet...")
510
+ for shot in args.shots:
511
+ scores = []
512
+ for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
513
+ imagenet_score = evaluate_imagenet(
514
+ model=flamingo,
515
+ tokenizer=tokenizer,
516
+ image_processor=image_processor,
517
+ batch_size=args.batch_size,
518
+ num_samples=args.num_samples,
519
+ num_shots=shot,
520
+ device=args.device,
521
+ seed=seed,
522
+ imagenet_root=args.imagenet_root,
523
+ )
524
+ print(
525
+ f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
526
+ )
527
+ scores.append(imagenet_score)
528
+ print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
529
+ results["imagenet"].append(
530
+ {"shots": shot, "trials": scores, "mean": np.mean(scores)}
531
+ )
532
+
533
+ if args.eval_refcoco:
534
+ print("Evaluating on RefCOCO...")
535
+ refcoco_score = evaluate_refcoco(
536
+ model=flamingo,
537
+ tokenizer=tokenizer,
538
+ image_processor=image_processor,
539
+ batch_size=args.batch_size,
540
+ device=args.device,
541
+ tsvfile=args.refcoco_tsvfile,
542
+ vis_embed_size=vis_embed_size,
543
+ rank=args.rank,
544
+ world_size=args.world_size,
545
+ id=args.id,
546
+ )
547
+ results["refcoco"].append(
548
+ {"score": refcoco_score}
549
+ )
550
+ if args.eval_aro:
551
+ print("Evaluating on ARO...")
552
+ _func = evaluate_aro
553
+ # print("Evaluating on ARO ORI...")
554
+ # _func = evaluate_aro_ori
555
+ aro_score = _func(
556
+ model=flamingo,
557
+ tokenizer=tokenizer,
558
+ image_processor=image_processor,
559
+ batch_size=args.batch_size,
560
+ device=args.device,
561
+ tsvfile=args.refcoco_tsvfile,
562
+ vis_embed_size=vis_embed_size,
563
+ rank=args.rank,
564
+ world_size=args.world_size,
565
+ id=args.id,
566
+ add_relation=args.relation,
567
+ )
568
+ results["aro"].append(
569
+ {"score": aro_score}
570
+ )
571
+ if args.eval_pisc:
572
+ print("Evaluating on ARO...")
573
+ aro_score = evaluate_pisc(
574
+ model=flamingo,
575
+ tokenizer=tokenizer,
576
+ image_processor=image_processor,
577
+ batch_size=args.batch_size,
578
+ device=args.device,
579
+ tsvfile=args.refcoco_tsvfile,
580
+ vis_embed_size=vis_embed_size,
581
+ rank=args.rank,
582
+ world_size=args.world_size,
583
+ id=args.id,
584
+ )
585
+ results["pisc"].append(
586
+ {"score": aro_score}
587
+ )
588
+
589
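+ # Stack the processed images into a single tensor, adding singleton dims per image (presumably batch x num_images x frames x C x H x W for the Flamingo-style vision input).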
+ def prepare_batch_images(batch, image_processor):
590
+ batch_images = None
591
+ for b in batch:
592
+ b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
593
+ if batch_images is None:
594
+ batch_images = b_image
595
+ else:
596
+ batch_images = torch.cat([batch_images, b_image], dim=0)
597
+ return batch_images
598
+
599
+ def get_outputs(
600
+ model,
601
+ batch_images,
602
+ attention_mask,
603
+ max_generation_length,
604
+ min_generation_length,
605
+ num_beams,
606
+ length_penalty,
607
+ input_ids,
608
+ image_start_index_list=None,
609
+ image_nums=None,
610
+ bad_words_ids=None,
611
+ ):
612
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
613
+ outputs = model.generate(
614
+ batch_images,
615
+ input_ids,
616
+ attention_mask=attention_mask,
617
+ max_new_tokens=max_generation_length,
618
+ min_length=min_generation_length,
619
+ num_beams=num_beams,
620
+ length_penalty=length_penalty,
621
+ image_start_index_list=image_start_index_list,
622
+ image_nums=image_nums,
623
+ bad_words_ids=bad_words_ids,
624
+ )
625
+
626
+ outputs = outputs[:, len(input_ids[0]) :]
627
+ return outputs
628
+
629
+
630
+ def evaluate_coco_flickr(
631
+ model,
632
+ tokenizer,
633
+ image_processor,
634
+ batch_size,
635
+ image_dir_path,
636
+ annotations_json_path,
637
+ seed=42,
638
+ max_generation_length=20,
639
+ num_beams=1,
640
+ length_penalty=-2.0,
641
+ device=-1,
642
+ is_flickr=False,
643
+ vis_embed_size=None,
644
+ rank=0,
645
+ world_size=1,
646
+ id=0,
647
+ ):
648
+ """Evaluate a model on COCO dataset.
649
+
650
+ Args:
651
+ model (nn.Module): model to evaluate
652
+ tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
653
+ image_processor : image processor for the model
654
+ batch_size (int): batch size
655
+ image_dir_path (str, optional): path to the directory containing the images.
656
+ annotations_json_path (str, optional): path to the json file containing the annotations.
657
+ seed (int, optional): seed for random number generator. Defaults to 42.
658
+ max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20.
659
+ num_beams (int, optional): number of beams to use for beam search. Defaults to 1.
660
+ length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
661
+ num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
662
+ query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
663
+ num_shots (int, optional): number of in-context samples to use. Defaults to 8.
664
+ device (int, optional): device to use. Defaults to -1.
665
+ num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
666
+ is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
667
+
668
+ Returns:
669
+ float: CIDEr score
670
+
671
+ """
672
+ # eval_dataset = COCOFlickrDataset(
673
+ # image_dir_path=image_dir_path,
674
+ # annotations_path=annotations_json_path,
675
+ # is_flickr=is_flickr,
676
+ # )
677
+ coco_dataset = load_dataset("coco_caption")
678
+ eval_dataset = coco_dataset["test"]
679
+
680
+
681
+ model.eval().cuda()
682
+ predictions = defaultdict()
683
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
684
+ # if "peft" in lang_encoder_name:
685
+ # lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
686
+ try:
687
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
688
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
689
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
690
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
691
+ except:
692
+ pass
693
+
694
+ def get_prompt(sample):
695
+ return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
696
+
697
+ tokenizer.padding_side = "left"
698
+ cnt = 0
699
+ if world_size > 1:
700
+ torch.distributed.barrier()
701
+ desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
702
+ for ii, batch in enumerate(more_itertools.chunked(
703
+ tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
704
+ )):
705
+ if ii % world_size != rank:
706
+ continue
707
+ cnt += len(batch)
708
+ batch_images = prepare_batch_images(
709
+ batch=batch,
710
+ image_processor=image_processor,
711
+ ).cuda()
712
+ batch_text = [get_prompt(s) for s in batch]
713
+ encodings = tokenizer(
714
+ batch_text,
715
+ padding="longest",
716
+ truncation=True,
717
+ return_tensors="pt",
718
+ max_length=2000,
719
+ )
720
+ input_ids = encodings["input_ids"].cuda()
721
+ attention_mask = encodings["attention_mask"].cuda()
722
+ skip_special_tokens = False
723
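+ # Legacy-checkpoint path (assumed older OPT prompt layout): shift the <|#image#|>/<|#endofimage#|> markers one token earlier and refill their old slots with pad/bos so the prompt matches that format.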
+ if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
724
+ if rank == 0:
725
+ tqdm.write("use legacy model")
726
+ skip_special_tokens = True
727
+ for i in range(len(input_ids)):
728
+ media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
729
+ endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
730
+ input_ids[i, media_token_index - 1] = media_token_id
731
+ input_ids[i, media_token_index] = pad_token_id
732
+ input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
733
+ input_ids[i, endofmedia_token_index] = bos_token_id
734
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
735
+ image_start_index_list = [[x] for x in image_start_index_list]
736
+ image_nums = [1] * len(input_ids)
737
+ if "llama" in lang_encoder_name:
738
+ attention_mask[input_ids == 0] = 0
739
+ outputs = get_outputs(
740
+ model=model,
741
+ batch_images=batch_images,
742
+ attention_mask=attention_mask,
743
+ max_generation_length=30,
744
+ min_generation_length=8,
745
+ num_beams=5,
746
+ length_penalty=0,
747
+ input_ids=input_ids,
748
+ image_start_index_list=image_start_index_list,
749
+ image_nums=image_nums,
750
+ )
751
+ new_predictions = [
752
+ postprocess_captioning_generation(out).replace('"', "")
753
+ for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
754
+ ]
755
+ # if rank == 0:
756
+ # tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
757
+
758
+ for i, sample in enumerate(batch):
759
+ predictions[int(sample["image_id"])] = {
760
+ "caption": new_predictions[i],
761
+ }
762
+ results_path = (
763
+ f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
764
+ if is_flickr
765
+ else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
766
+ )
767
+ with open(results_path, "w") as f:
768
+ f.write(
769
+ json.dumps(
770
+ [
771
+ {"image_id": k, "caption": predictions[k]["caption"]}
772
+ for k in predictions
773
+ ],
774
+ indent=2,
775
+ )
776
+ )
777
+ print("save to", results_path)
778
+ del predictions
779
+ time.sleep(10)
780
+ if world_size > 1:
781
+ torch.distributed.barrier()
782
+ if rank == 0:
783
+ print(f"evaluate on rank {rank}. world size is {world_size}")
784
+ predictions = []
785
+ for rank_i in range(world_size):
786
+ part_results_path = (
787
+ f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
788
+ if is_flickr
789
+ else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
790
+ )
791
+ print("load", part_results_path)
792
+ predictions.extend(json.load(open(part_results_path)))
793
+ os.remove(part_results_path)
794
+ print("num:", len(predictions))
795
+ results_path = (
796
+ f"flickrresults_{lang_encoder_name}.json"
797
+ if is_flickr
798
+ else f"cocoresults_{lang_encoder_name}.json"
799
+ )
800
+ json.dump(predictions, open(results_path, "w"), indent=2)
801
+
802
+ metrics = compute_cider(
803
+ result_path=results_path,
804
+ annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
805
+ )
806
+ os.makedirs("eval_results", exist_ok=True)
807
+ acc = metrics["CIDEr"]
808
+ with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
809
+ f.write(json.dumps(predictions, indent=2))
810
+
811
+ # delete the temporary file
812
+ os.remove(results_path)
813
+ else:
814
+ metrics = {}
815
+ metrics["CIDEr"] = 0.0
816
+
817
+ return metrics["CIDEr"]
818
+
819
+
820
+ def evaluate_vqa(
821
+ model,
822
+ tokenizer,
823
+ image_processor,
824
+ batch_size,
825
+ image_dir_path=None,
826
+ questions_json_path=None,
827
+ annotations_json_path=None,
828
+ vqa_dataset="vqa",
829
+ vis_embed_size=None,
830
+ rank=0,
831
+ world_size=1,
832
+ id=0,
833
+ ):
834
+ """
835
+ Evaluate a model on VQA datasets. Currently supports VQA v2.0.
836
+
837
+ Args:
838
+ model (nn.Module): model to evaluate
839
+ tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
840
+ image_processor : image processor for the model
841
+ batch_size (int): batch size
842
+ image_dir_path (str): path to image directory
843
+ questions_json_path (str): path to questions json file
844
+ annotations_json_path (str): path to annotations json file
845
+ seed (int, optional): random seed. Defaults to 42.
846
+ max_generation_length (int, optional): max generation length. Defaults to 5.
847
+ num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
848
+ length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
849
+ num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
850
+ query_set_size (int, optional): size of the query set. Defaults to 2048.
851
+ num_shots (int, optional): number of shots to use. Defaults to 8.
852
+ device (int, optional): device to use. Defaults to -1 (cpu).
853
+ num_workers (int, optional): number of workers to use. Defaults to 4.
854
+ vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
855
+ Returns:
856
+ float: accuracy score
857
+ """
858
+ if world_size > 1:
859
+ torch.distributed.barrier()
860
+ if vqa_dataset == "gqa":
861
+ eval_dataset = GQADataset()
862
+ else:
863
+ eval_dataset = VQADataset(
864
+ image_dir_path=image_dir_path,
865
+ question_path=questions_json_path,
866
+ annotations_path=annotations_json_path,
867
+ vqa_dataset=vqa_dataset,
868
+ )
869
+ postprocessor = OKVQAPostProcess()
870
+ try:
871
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
872
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
873
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
874
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
875
+ except:
876
+ pass
877
+ def get_prompt(sample):
878
+ return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
879
+ # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
880
+
881
+ model.eval().cuda()
882
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
883
+ if "peft" in lang_encoder_name:
884
+ lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
885
+ predictions = []
886
+ tokenizer.padding_side = "left"
887
+ if world_size > 1:
888
+ torch.distributed.barrier()
889
+ this_tot = 0
890
+ for ii, batch in enumerate(more_itertools.chunked(
891
+ tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
892
+ )):
893
+ if ii % world_size != rank:
894
+ continue
895
+ batch_images = prepare_batch_images(
896
+ batch=batch,
897
+ image_processor=image_processor,
898
+ ).cuda()
899
+ batch_text = [get_prompt(s) for s in batch]
900
+ encodings = tokenizer(
901
+ batch_text,
902
+ return_tensors="pt",
903
+ padding="longest",
904
+ truncation=True,
905
+ max_length=2000,
906
+ )
907
+ input_ids = encodings["input_ids"].cuda()
908
+ attention_mask = encodings["attention_mask"].cuda()
909
+ skip_special_tokens = True
910
+ if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
911
+ if rank == 0:
912
+ tqdm.write("use legacy model")
913
+ for i in range(len(input_ids)):
914
+ media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
915
+ endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
916
+ input_ids[i, media_token_index - 1] = media_token_id
917
+ input_ids[i, media_token_index] = pad_token_id
918
+ input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
919
+ input_ids[i, endofmedia_token_index] = bos_token_id
920
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
921
+ image_start_index_list = [[x] for x in image_start_index_list]
922
+ image_nums = [1] * len(input_ids)
923
+ if "llama" in lang_encoder_name:
924
+ attention_mask[input_ids == 0] = 0
925
+ outputs = get_outputs(
926
+ model=model,
927
+ batch_images=batch_images,
928
+ attention_mask=attention_mask,
929
+ max_generation_length=10,
930
+ min_generation_length=1,
931
+ num_beams=5,
932
+ length_penalty=0,
933
+ input_ids=input_ids,
934
+ image_start_index_list=image_start_index_list,
935
+ image_nums=image_nums,
936
+ )
937
+ # postprocess begin
938
+ new_predictions = [
939
+ out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
940
+ ]
941
+ if vqa_dataset == "ok_vqa":
942
+ new_predictions = postprocessor._lemmatize(new_predictions)
943
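+ # Optional lenient matching (model.special): if any annotated answer string occurs in the generated text, snap the prediction to it, with extra handling that maps "cant"/"can" phrasings to no/yes.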
+ if model.special:
944
+ for i in range(len(new_predictions)):
945
+ for answer, _ in Counter(batch[i]['answers']).most_common():
946
+ if answer in new_predictions[i]:
947
+ new_predictions[i] = answer
948
+ break
949
+ if "cant" in new_predictions[i] and "no" == answer:
950
+ new_predictions[i] = answer
951
+ break
952
+ if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
953
+ new_predictions[i] = answer
954
+ break
955
+
956
+ this_tot += 1
957
+ if rank == 0 and this_tot % 20 == 0:
958
+ for i in range(1):
959
+ tqdm.write(f"question: {batch[i]['question']}\nanswer: {batch[i]['answers']}model output: " + new_predictions[i])
960
+
961
+ predictions.extend(
962
+ [
963
+ {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
964
+ for p, sample in zip(new_predictions, batch)
965
+ ]
966
+ )
967
+ with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
968
+ f.write(json.dumps(predictions))
969
+ print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
970
+
971
+ time.sleep(10)
972
+ if world_size > 1:
973
+ torch.distributed.barrier()
974
+ if rank == 0:
975
+ print(f"evaluate on rank {rank}. world size is {world_size}")
976
+ predictions = []
977
+ for rank_i in range(world_size):
978
+ print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
979
+ predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
980
+ os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
981
+ print("num:", len(predictions))
982
+ # save the predictions to a temporary file
983
+ random_uuid = str(uuid.uuid4())
984
+ with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
985
+ f.write(json.dumps(predictions, indent=4))
986
+
987
+ if vqa_dataset == "gqa":
988
+ acc = compute_gqa_accuracy(predictions)
989
+ else:
990
+ acc = compute_vqa_accuracy(
991
+ f"{vqa_dataset}results_{random_uuid}.json",
992
+ questions_json_path,
993
+ annotations_json_path,
994
+ vqa_dataset=vqa_dataset,
995
+ )
996
+ print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
997
+ os.makedirs("eval_results", exist_ok=True)
998
+ with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
999
+ f.write(json.dumps(predictions, indent=2))
1000
+
1001
+ # delete the temporary file
1002
+ os.remove(f"{vqa_dataset}results_{random_uuid}.json")
1003
+ else:
1004
+ time.sleep(5)
1005
+ acc = 0.0
1006
+ if world_size > 1:
1007
+ torch.distributed.barrier()
1008
+ return acc
1009
+
1010
+
1011
+ def evaluate_refcoco(
1012
+ model,
1013
+ tokenizer,
1014
+ image_processor,
1015
+ batch_size,
1016
+ tsvfile,
1017
+ max_generation_length=20,
1018
+ num_beams=3,
1019
+ length_penalty=-2.0,
1020
+ device=-1,
1021
+ vis_embed_size=None,
1022
+ rank=0,
1023
+ world_size=1,
1024
+ id=0,
1025
+ ):
1026
+ model.eval().cuda()
1027
+ loc_token_ids = []
1028
+ for i in range(1000):
1029
+ loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
1030
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
1031
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
1032
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
1033
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
1034
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
1035
+ # all_ids = set(range(model.lang_encoder.lm_head.out_features))
1036
+ # bad_words_ids = list(all_ids - set(loc_token_ids))
1037
+ # bad_words_ids = [[b] for b in bad_words_ids]
1038
+ # min_loc_token_id = min(loc_token_ids)
1039
+ # max_loc_token_id = max(loc_token_ids)
1040
+ total = 0
1041
+ correct = 0
1042
+ ious = []
1043
+ if "refcocog" in tsvfile:
1044
+ dataset_name = "refcocog"
1045
+ elif "refcocoplus" in tsvfile:
1046
+ dataset_name = "refcocoplus"
1047
+ else:
1048
+ dataset_name = "refcoco"
1049
+ with open(tsvfile, "r") as f:
1050
+ lines = f.readlines()
1051
+ pbar = tqdm(lines, disable=(rank != 0))
1052
+ for ii, line in enumerate(pbar):
1053
+ if ii % world_size != rank:
1054
+ continue
1055
+ total += 1
1056
+ line = line.rstrip()
1057
+ uniq_id, image_id, text, region_coord, image = line.split("\t")
1058
+
1059
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
1060
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
1061
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
1062
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
1063
+
1064
+ gt_box = np.array(list(map(float, region_coord.split(","))))
1065
+ width = image.width
1066
+ height = image.height
1067
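+ # Resize the image to 224x224 and rescale the ground-truth box into the same 224x224 coordinate frame.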
+ image = image.resize((224, 224))
1068
+ gt_box = gt_box / np.array([width, height, width, height]) * 224
1069
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1070
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
1071
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
1072
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
1073
+ # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
1074
+
1075
+
1076
+ encodings = tokenizer(
1077
+ prompt,
1078
+ padding="longest",
1079
+ truncation=True,
1080
+ return_tensors="pt",
1081
+ max_length=2000,
1082
+ )
1083
+ input_ids = encodings["input_ids"]
1084
+ attention_mask = encodings["attention_mask"]
1085
+ # attention_mask[input_ids == prebox_token_id] = 0
1086
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1087
+ image_start_index_list = [[x] for x in image_start_index_list]
1088
+ image_nums = [1] * len(input_ids)
1089
+ vision_x = batch_images.cuda()
1090
+ lang_x = input_ids.cuda()
1091
+ attention_mask = attention_mask.cuda()
1092
+
1093
+ model.debug_id = 0
1094
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
1095
+ outputs = model(
1096
+ vision_x=vision_x,
1097
+ lang_x=lang_x,
1098
+ attention_mask=attention_mask,
1099
+ labels=None,
1100
+ image_nums=image_nums,
1101
+ image_start_index_list=image_start_index_list,
1102
+ added_bbox_list=None,
1103
+ add_box=False,
1104
+ )
1105
+ boxes = outputs["boxes"]
1106
+ scores = outputs["scores"]
1107
+ if len(scores) > 0:
1108
+ box = boxes[scores.argmax()]
1109
+ iou = get_iou(box, gt_box)
1110
+ else:
1111
+ iou = 0.0
1112
+ # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
1113
+ tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}")
1114
+ if iou >= 0.5:
1115
+ correct += 1
1116
+ pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
1117
+ # open_cv_image = np.array(image)
1118
+ # # Convert RGB to BGR
1119
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
1120
+ # for box, score in zip(boxes, scores):
1121
+ # open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
1122
+ # cv2.imwrite("output.jpg", open_cv_image)
1123
+ # print(boxes)
1124
+ # print(scores)
1125
+ # exit()
1126
+
1127
+
1128
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1129
+ f.write(json.dumps([total, correct]))
1130
+ if world_size > 1:
1131
+ torch.distributed.barrier()
1132
+ if rank == 0:
1133
+ total = 0
1134
+ correct = 0
1135
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1136
+ for rank_i in range(world_size):
1137
+ [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1138
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1139
+ total += total_part
1140
+ correct += correct_part
1141
+ score = correct / total
1142
+ print("score:", score)
1143
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
1144
+ pass
1145
+ else:
1146
+ score = 0.0
1147
+ if world_size > 1:
1148
+ torch.distributed.barrier()
1149
+ return score
1150
+
1151
+
1152
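+ # Worked example: "the dog is sitting on the floor" parses to obj_A="dog", relation="sitting on", obj_B="floor", which are then wrapped in the <|#object#|>/<|#visual#|>/<|#previsual#|> grounding tags.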
+ def preprocess_visual_info(Text):
1153
+ text = Text.split(" ")
1154
+ for is_idx, t in enumerate(text):
1155
+ if t == "is":
1156
+ break
1157
+ the_idx = is_idx
1158
+ while text[the_idx] != "the":
1159
+ the_idx -= 1
1160
+ obj_A = " ".join(text[the_idx+1:is_idx])
1161
+ second_the_idx = len(text) - 1
1162
+ while text[second_the_idx] != "the":
1163
+ second_the_idx -= 1
1164
+ obj_B = " ".join(text[second_the_idx+1:])
1165
+ relation = " ".join(text[is_idx+1:second_the_idx])
1166
+ visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
1167
+ visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>"
1168
+ Text = f"{visual_obj_A} is {relation} {visual_obj_B}"
1169
+ return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation
1170
+
1171
+
1172
+
1173
+ def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox, debug=False, return_all=False):
1174
+ assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
1175
+ encodings = tokenizer(
1176
+ prompt,
1177
+ padding="longest",
1178
+ truncation=True,
1179
+ return_tensors="pt",
1180
+ max_length=2000,
1181
+ )
1182
+ input_ids = encodings["input_ids"]
1183
+ attention_mask = encodings["attention_mask"]
1184
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1185
+ image_start_index_list = [[x] for x in image_start_index_list]
1186
+ image_nums = [1] * len(input_ids)
1187
+ vision_x = batch_images.cuda()
1188
+ lang_x = input_ids.cuda()
1189
+ attention_mask = attention_mask.cuda()
1190
+ prebox_mask = (input_ids == prebox_token_id)
1191
+ if mask_prebox and prebox_mask.any():
1192
+ attention_mask[prebox_mask] = 0
1193
+
1194
+ model.debug_id = 0
1195
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
1196
+ outputs = model(
1197
+ vision_x=vision_x,
1198
+ lang_x=lang_x,
1199
+ attention_mask=attention_mask,
1200
+ labels=None,
1201
+ image_nums=image_nums,
1202
+ image_start_index_list=image_start_index_list,
1203
+ added_bbox_list=visual_box_list,
1204
+ add_box=visual_box_list is not None,
1205
+ relations=None,
1206
+ debug_mode=False,
1207
+ )
1208
+ boxes = outputs["boxes"]
1209
+ scores = outputs["scores"]
1210
+ if debug:
1211
+ import pdb; pdb.set_trace()
1212
+ if return_all:
1213
+ return boxes, scores
1214
+ if len(scores) == 0:
1215
+ return None, None
1216
+ else:
1217
+ return boxes[scores.argmax()], scores.max()
1218
+
1219
+
1220
+ def evaluate_aro(
1221
+ model,
1222
+ tokenizer,
1223
+ image_processor,
1224
+ batch_size,
1225
+ tsvfile,
1226
+ max_generation_length=20,
1227
+ num_beams=3,
1228
+ length_penalty=-2.0,
1229
+ device=-1,
1230
+ vis_embed_size=None,
1231
+ rank=0,
1232
+ world_size=1,
1233
+ id=0,
1234
+ add_visual=True,
1235
+ add_relation=False,
1236
+ subset=True,
1237
+ choose_left_right=True,
1238
+ ):
1239
+ os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
1240
+ from groundingdino.demo.caption_grounder import caption_grounder
1241
+ generator = caption_grounder(
1242
+ config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
1243
+ checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
1244
+ cpu_only=False,
1245
+ box_threshold=0.1, text_threshold=0.1,
1246
+ )
1247
+ dataset_name = "aro"
1248
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
1249
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
1250
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
1251
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
1252
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
1253
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
1254
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
1255
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
1256
+ model.eval().cuda()
1257
+ total = 0
1258
+ correct = 0
1259
+ from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
1260
+ vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
1261
+ with open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/unilm/kosmos-2/labels.json") as f:
1262
+ all_labels = json.load(f)
1263
+ label_ids = tokenizer(all_labels).input_ids
1264
+ label_ids = sorted(list(set([x[0] for x in label_ids])))
1265
+
1266
+ if subset:
1267
+ subset_idx = json.load(open("aro_subset.json"))
1268
+ pbar = tqdm(subset_idx, disable=(rank != 0))
1269
+ else:
1270
+ pbar = tqdm(vgr_dataset, disable=(rank != 0))
1271
+
1272
+
1273
+ exist_total = 0
1274
+ for ii, sample in enumerate(pbar):
1275
+ if subset:
1276
+ ORI_IDX = int(sample)
1277
+ sample = vgr_dataset[sample]
1278
+ # if ORI_IDX != 19036:
1279
+ # continue
1280
+ if ii % world_size != rank:
1281
+ continue
1282
+
1283
+ not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0])
1284
+ if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right):
1285
+ if rank == 0:
1286
+ tqdm.write(f"SKIP: {sample['caption_options'][1]}")
1287
+ continue
1288
+ total += 1
1289
+ image = sample["image_options"][0]
1290
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
1291
+ image = image.resize((224, 224))
1292
+
1293
+ chosen_idx = 0
1294
+ text = sample["caption_options"][chosen_idx] # 1 is true caption
1295
+ # text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog"
1296
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1297
+ text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
1298
+
1299
+
1300
+ first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
1301
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
1302
+ first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox=True, return_all=False)
1303
+
1304
+
1305
+ # use grounding DINO to get the first bbox
1306
+ # caption = f"{obj_A}"
1307
+ # with torch.no_grad():
1308
+ # logits, boxes = generator.ground_caption_raw(image_pil=image, caption=caption)
1309
+ # boxes_filt, pred_phrases = generator.postprocess(logits, boxes, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
1310
+ # objects = {}
1311
+ # for box, phrase in zip(boxes_filt, pred_phrases):
1312
+ # obj, score = phrase
1313
+ # obj = obj[0]
1314
+ # if obj not in objects:
1315
+ # objects[obj] = (score, box)
1316
+ # if objects[obj][0] < score:
1317
+ # objects[obj] = (score, box)
1318
+ # try:
1319
+ # first_box = objects[obj_A][1].clone()
1320
+ # first_box[:2] -= first_box[2:] / 2
1321
+ # first_box[2:] += first_box[:2]
1322
+ # first_box = first_box.clamp(0, 0.99) * 224.0
1323
+ # first_box = first_box.numpy()
1324
+ # first_score = objects[obj_A][0]
1325
+ # except:
1326
+ # first_box = None
1327
+
1328
+ if first_box is None:
1329
+ text_A = "the " + obj_A
1330
+ added_bbox_list = None
1331
+ else:
1332
+ text_A = visual_obj_A
1333
+ added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
1334
+
1335
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
1336
+ pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
1337
+ prebox_token_id, mask_prebox=False, debug=False, return_all=True)
1338
+
1339
+
1340
+ open_cv_image = np.array(image)
1341
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
1342
+ font = cv2.FONT_HERSHEY_SIMPLEX
1343
+ fontScale = 0.5
1344
+ color = (0, 0, 0)
1345
+ thickness = 1
1346
+ if first_box is not None:
1347
+ open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
1348
+ exist_flag = False
1349
+ for box, score in zip(pre_boxes, pre_scores):
1350
+ if score >= 0.5:
1351
+ exist_flag = True
1352
+ open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (0, 255, 0), 2)
1353
+ org = box[:2].astype(int)
1354
+ org[1] += 20
1355
+ org[0] += 10
1356
+ open_cv_image = cv2.putText(open_cv_image, f"{score:.2f}", org, font, fontScale, (255, 255, 255), thickness, cv2.LINE_AA)
1357
+ open_cv_image = cv2.resize(open_cv_image, (512, 512))
1358
+ put_text = sample["caption_options"][chosen_idx]
1359
+ org = [10, 20]
1360
+ open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
1361
+ # cv2.imwrite(f"visualization/aro_results_{id}/{str(ORI_IDX).zfill(8)}.jpg", open_cv_image)
1362
+ if exist_flag:
1363
+ exist_total += 1
1364
+ continue
1365
+
1366
+
1367
+
1368
+ if pre_boxes is None:
1369
+ pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
1370
+ pre_scores = [1.0]
1371
+
1372
+ rank_list = []
1373
+ # pre_boxes = [pre_boxes[0]]
1374
+ # pre_scores = [pre_scores[0]]
1375
+ for pre_box, pre_score in zip(pre_boxes, pre_scores):
1376
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
1377
+
1378
+ encodings = tokenizer(
1379
+ prompt,
1380
+ padding="longest",
1381
+ truncation=True,
1382
+ return_tensors="pt",
1383
+ max_length=512,
1384
+ )
1385
+ input_ids = encodings["input_ids"]
1386
+ attention_mask = encodings["attention_mask"]
1387
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1388
+ image_start_index_list = [[x] for x in image_start_index_list]
1389
+ image_nums = [1] * len(input_ids)
1390
+ vision_x = batch_images.cuda()
1391
+ lang_x = input_ids.cuda()
1392
+ attention_mask = attention_mask.cuda()
1393
+ labels = lang_x.clone()
1394
+
1395
+ answer_start_idx = (labels == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1] + 1
1396
+ # pre_box = None
1397
+ labels[0, :answer_start_idx] = -100
1398
+ # # labels[labels == endofobject_token_id] = -100
1399
+ # labels[:, 0] = -100
1400
+ # labels[labels == visual_token_id] = -100
1401
+ # labels[labels == box_token_id] = -100
1402
+ # labels[labels == previsual_token_id] = -100
1403
+ # labels[labels == prebox_token_id] = -100
1404
+ # labels[labels == endofattr_token_id] = -100
1405
+ # labels[labels == tokenizer.pad_token_id] = -100
1406
+ # labels[labels == media_token_id] = -100
1407
+ # labels[labels == endofmedia_token_id] = -100
1408
+ answer_ids = tokenizer(f" {obj_B}", add_special_tokens=False)["input_ids"]
1409
+ labels[input_ids == visual_token_id] = -100
1410
+ labels[input_ids == box_token_id] = -100
1411
+ labels[input_ids == endofattr_token_id] = -100
1412
+ labels[input_ids == previsual_token_id] = -100
1413
+ labels[input_ids == prebox_token_id] = -100
1414
+ labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
1415
+ labels[torch.roll(input_ids == box_token_id, 1)] = -100
1416
+ labels[:, 0] = -100
1417
+ labels[input_ids == tokenizer.pad_token_id] = -100
1418
+ labels[input_ids == media_token_id] = -100
1419
+ labels[input_ids == endofmedia_token_id] = -100
1420
+
1421
+ added_bbox_list = None
1422
+ if add_visual:
1423
+ added_bbox_list = []
1424
+ if first_box is not None:
1425
+ added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
1426
+ if pre_box is not None:
1427
+ added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
1428
+ if added_bbox_list is not None and len(added_bbox_list) == 0:
1429
+ added_bbox_list = None
1430
+
1431
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
1432
+ outputs = model(
1433
+ vision_x=vision_x,
1434
+ lang_x=lang_x,
1435
+ attention_mask=attention_mask,
1436
+ labels=labels,
1437
+ image_nums=image_nums,
1438
+ image_start_index_list=image_start_index_list,
1439
+ added_bbox_list=added_bbox_list,
1440
+ add_box=added_bbox_list is not None,
1441
+ relations=None,
1442
+ )
1443
+ logits = outputs["logits"][0, answer_start_idx:]
1444
+ _rank = logits[0][label_ids].sort(descending=True).indices.tolist().index(label_ids.index(answer_ids[0]))
1445
+ rank_list.append(_rank)
1446
+ # open_cv_image = np.array(image)
1447
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
1448
+ # if first_box is not None:
1449
+ # open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
1450
+ # if pre_box is not None:
1451
+ # open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
1452
+
1453
+ # font = cv2.FONT_HERSHEY_SIMPLEX
1454
+ # org = [10, 20]
1455
+ # fontScale = 0.5
1456
+ # color = (0, 0, 0)
1457
+ # thickness = 1
1458
+ # open_cv_image = cv2.resize(open_cv_image, (512, 512))
1459
+ # put_text = sample["caption_options"][1]
1460
+ # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
1461
+ # org[1] += 20
1462
+ # put_text = "top10 in green box"
1463
+ # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
1464
+ # fontScale = 1.0
1465
+ # thickness = 2
1466
+ # for ind in logits_list[i][0].sort(descending=True).indices[:10]:
1467
+ # org[1] += 20
1468
+ # put_text = f"{tokenizer.decode(ind)}"
1469
+ # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
1470
+ # tqdm.write(f"{tokenizer.decode(logits_list[i][0].sort(descending=True).indices[:10])}")
1471
+ # tqdm.write(f"{rank_list}")
1472
+ final_rank = min(rank_list)
1473
+ if final_rank < 10:
1474
+ correct += 1
1475
+ TYPE = "CORRECT"
1476
+ if rank == 0:
1477
+ tqdm.write(f"correct: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
1478
+ else:
1479
+ TYPE = "WRONG"
1480
+ if rank == 0:
1481
+ tqdm.write(f"wrong: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
1482
+ # cv2.imwrite(f"visualization/aro_results_{id}/{TYPE}_{ORI_IDX}.jpg", open_cv_image)
1483
+ pbar.set_description(f"score: {correct / total:.4f} | {final_rank}")
1484
+
1485
+
1486
+
1487
+
1488
+
1489
+ print(exist_total)
1490
+ exit()
1491
+
1492
+
1493
+
1494
+
1495
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1496
+ f.write(json.dumps([total, correct]))
1497
+ if world_size > 1:
1498
+ torch.distributed.barrier()
1499
+ if rank == 0:
1500
+ total = 0
1501
+ correct = 0
1502
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1503
+ for rank_i in range(world_size):
1504
+ [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1505
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1506
+ total += total_part
1507
+ correct += correct_part
1508
+ score = correct / total
1509
+ print("score:", score, "total:", total)
1510
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
1511
+ pass
1512
+ else:
1513
+ score = 0.0
1514
+ if world_size > 1:
1515
+ torch.distributed.barrier()
1516
+ return score
1517
+
1518
+
1519
+
1520
+
1521
+ def evaluate_aro_ori(
1522
+ model,
1523
+ tokenizer,
1524
+ image_processor,
1525
+ batch_size,
1526
+ tsvfile,
1527
+ max_generation_length=20,
1528
+ num_beams=3,
1529
+ length_penalty=-2.0,
1530
+ device=-1,
1531
+ vis_embed_size=None,
1532
+ rank=0,
1533
+ world_size=1,
1534
+ id=0,
1535
+ add_visual=True,
1536
+ add_relation=False,
1537
+ subset=True,
1538
+ choose_left_right=True,
1539
+ only_highest=True,
1540
+ ):
1541
+ os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
1542
+ dataset_name = "aroori"
1543
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
1544
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
1545
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
1546
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
1547
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
1548
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
1549
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
1550
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
1551
+ model.eval().cuda()
1552
+ total = 0
1553
+ correct = 0
1554
+ from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
1555
+ vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
1556
+ if subset:
1557
+ subset_idx = json.load(open("aro_subset.json"))
1558
+ pbar = tqdm(subset_idx, disable=(rank != 0))
1559
+ else:
1560
+ pbar = tqdm(vgr_dataset, disable=(rank != 0))
1561
+ for ii, sample in enumerate(pbar):
1562
+ if subset:
1563
+ ORI_IDX = int(sample)
1564
+ sample = vgr_dataset[sample]
1565
+ # if ORI_IDX != 19036:
1566
+ # continue
1567
+ if ii % world_size != rank:
1568
+ continue
1569
+
1570
+ not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0])
1571
+ if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right):
1572
+ if rank == 0:
1573
+ tqdm.write(f"SKIP: {sample['caption_options'][1]}")
1574
+ continue
1575
+ total += 1
1576
+ image = sample["image_options"][0]
1577
+ # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
1578
+ image = image.resize((224, 224))
1579
+ debug_data = []
1580
+ final_losses = []
1581
+ for idx in range(2):
1582
+ text = sample["caption_options"][idx] # 1 is true caption
1583
+ # text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog"
1584
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1585
+ text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
1586
+ first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
1587
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
1588
+ first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox=True, return_all=False)
1589
+ if first_box is None:
1590
+ text_A = "the " + obj_A
1591
+ added_bbox_list = None
1592
+ else:
1593
+ text_A = visual_obj_A
1594
+ added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
1595
+
1596
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
1597
+ pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
1598
+ prebox_token_id, mask_prebox=False, debug=False, return_all=True)
1599
+ if pre_boxes is None:
1600
+ pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
1601
+ pre_scores = [1.0]
1602
+
1603
+ loss_list = []
1604
+ if only_highest:
1605
+ pre_boxes = [pre_boxes[0]]
1606
+ pre_scores = [pre_scores[0]]
1607
+ for pre_box, pre_score in zip(pre_boxes, pre_scores):
1608
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
1609
+
1610
+ encodings = tokenizer(
1611
+ prompt,
1612
+ padding="longest",
1613
+ truncation=True,
1614
+ return_tensors="pt",
1615
+ max_length=512,
1616
+ )
1617
+ input_ids = encodings["input_ids"]
1618
+ attention_mask = encodings["attention_mask"]
1619
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1620
+ image_start_index_list = [[x] for x in image_start_index_list]
1621
+ image_nums = [1] * len(input_ids)
1622
+ vision_x = batch_images.cuda()
1623
+ lang_x = input_ids.cuda()
1624
+ attention_mask = attention_mask.cuda()
1625
+ labels = lang_x.clone()
1626
+
1627
+
1628
+ labels[input_ids == visual_token_id] = -100
1629
+ labels[input_ids == box_token_id] = -100
1630
+ labels[input_ids == endofattr_token_id] = -100
1631
+ labels[input_ids == previsual_token_id] = -100
1632
+ labels[input_ids == prebox_token_id] = -100
1633
+ labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
1634
+ labels[torch.roll(input_ids == box_token_id, 1)] = -100
1635
+ labels[:, 0] = -100
1636
+ labels[input_ids == tokenizer.pad_token_id] = -100
1637
+ labels[input_ids == media_token_id] = -100
1638
+ labels[input_ids == endofmedia_token_id] = -100
1639
+
1640
+ added_bbox_list = None
1641
+ if add_visual:
1642
+ added_bbox_list = []
1643
+ if first_box is not None:
1644
+ added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
1645
+ if pre_box is not None:
1646
+ added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
1647
+ if added_bbox_list is not None and len(added_bbox_list) == 0:
1648
+ added_bbox_list = None
1649
+
1650
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
1651
+ outputs = model(
1652
+ vision_x=vision_x,
1653
+ lang_x=lang_x,
1654
+ attention_mask=attention_mask,
1655
+ labels=labels,
1656
+ image_nums=image_nums,
1657
+ image_start_index_list=image_start_index_list,
1658
+ added_bbox_list=added_bbox_list,
1659
+ add_box=added_bbox_list is not None,
1660
+ relations=None,
1661
+ )
1662
+ loss_list.append((outputs["loss"].sum() / (outputs["loss"] != 0).sum()).item())
1663
+ debug_data.append([outputs, first_box, first_score, pre_box, pre_scores])
1664
+ final_loss = min(loss_list)
1665
+ final_losses.append(final_loss)
1666
+ if final_losses[0] >= final_losses[1]:
1667
+ correct += 1
1668
+ else:
+ pass
1671
+ pbar.set_description(f"score: {correct / total:.4f} | {final_losses[0]:.2f} vs {final_losses[1]:.2f}")
1672
+
1673
+
1674
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1675
+ f.write(json.dumps([total, correct]))
1676
+ if world_size > 1:
1677
+ torch.distributed.barrier()
1678
+ if rank == 0:
1679
+ total = 0
1680
+ correct = 0
1681
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1682
+ for rank_i in range(world_size):
1683
+ [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1684
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1685
+ total += total_part
1686
+ correct += correct_part
1687
+ score = correct / total
1688
+ print("score:", score, "total:", total)
1689
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
1690
+ pass
1691
+ else:
1692
+ score = 0.0
1693
+ if world_size > 1:
1694
+ torch.distributed.barrier()
1695
+ return score
1696
+
1697
+
1698
+ def evaluate_pisc(
1699
+ model,
1700
+ tokenizer,
1701
+ image_processor,
1702
+ batch_size,
1703
+ tsvfile,
1704
+ max_generation_length=20,
1705
+ num_beams=3,
1706
+ length_penalty=-2.0,
1707
+ device=-1,
1708
+ vis_embed_size=None,
1709
+ rank=0,
1710
+ world_size=1,
1711
+ id=0,
1712
+ add_visual=True,
1713
+ ):
1714
+ from open_flamingo.train.instruction_template import PISC_TEMPLATES
1715
+ dataset_name = "pisc"
1716
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
1717
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
1718
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
1719
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
1720
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
1721
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
1722
+ model.train().cuda()
1723
+
1724
+ dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd")
1725
+ pbar = tqdm(dataset, disable=(rank != 0))
1726
+
1727
+ rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
1728
+ rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)}
1729
+ gt = []
1730
+ pred_scores = []
1731
+ for III, sample in enumerate(pbar):
1732
+ if III % world_size != rank:
1733
+ continue
1734
+ image_path, dataset, data = sample
1735
+ image = Image.open(image_path)
1736
+ size = image_processor.transforms[0].size
1737
+ image = image.resize((size, size))
1738
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
1739
+ boxA = data[0]
1740
+ boxB = data[1]
1741
+ gt_relation = data[2]
1742
+ losses = []
1743
+ for i_rel, option_rel in enumerate(rel_id_to_type):
1744
+ text = PISC_TEMPLATES[0].format(relation=option_rel)
1745
+ added_bbox = [
1746
+ torch.tensor([boxA]).cuda(),
1747
+ torch.tensor([boxB]).cuda(),
1748
+ ]
1749
+ caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}"
1750
+ encodings = tokenizer(
1751
+ caption,
1752
+ padding="longest",
1753
+ truncation=True,
1754
+ return_tensors="pt",
1755
+ max_length=2000,
1756
+ )
1757
+ input_ids = encodings["input_ids"]
1758
+ attention_mask = encodings["attention_mask"]
1759
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
1760
+ image_start_index_list = [[x] for x in image_start_index_list]
1761
+ image_nums = [1] * len(input_ids)
1762
+ vision_x = batch_images.cuda()
1763
+ lang_x = input_ids.cuda()
1764
+ attention_mask = attention_mask.cuda()
1765
+
1766
+ labels = lang_x.clone()
1767
+ labels[labels == tokenizer.pad_token_id] = -100
1768
+ if add_visual:
1769
+ # endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True))
1770
+ # endofattr_next_token_index[1] += 1
1771
+ # endofattr_next_token_id = labels[endofattr_next_token_index]
1772
+ # </obj><visual><box></attr>NEXT_WORD
1773
+ # </obj> predict NEXT_WORD
1774
+ # <visual><box></attr> predict nothing
1775
+ labels[labels == visual_token_id] = -100
1776
+ labels[labels == box_token_id] = -100
1777
+ labels[labels == endofattr_token_id] = -100
1778
+ # labels[endofattr_next_token_index] = -100
1779
+ labels[:, 0] = -100
1780
+ answer_token_id = tokenizer(" Answer").input_ids[0]
1781
+ answer_token_loc = (input_ids == answer_token_id).nonzero()
1782
+ for batch_idx, idx in answer_token_loc:
1783
+ labels[batch_idx][:idx+2] = -100
1784
+
1785
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
1786
+ outputs = model(
1787
+ vision_x=vision_x,
1788
+ lang_x=lang_x,
1789
+ attention_mask=attention_mask,
1790
+ labels=labels,
1791
+ image_nums=image_nums,
1792
+ image_start_index_list=image_start_index_list,
1793
+ added_bbox_list=added_bbox,
1794
+ add_box=added_bbox is not None,
1795
+ )
1796
+ loss_total = outputs.loss.reshape(labels.shape[0], -1)
1797
+ loss = loss_total.sum() / (loss_total != 0).sum()
1798
+ losses.append(loss.item())
1799
+ pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum())
1800
+ gt.append(rel_type_to_id[gt_relation])
1801
+ gt = np.array(gt)
1802
+ pred_scores = np.array(pred_scores)
1803
+ pred = pred_scores.argmax(1)
1804
+
1805
+
1806
+ print("total num:", len(gt))
1807
+ recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
1808
+ print("recalls:", recalls)
1809
+
1810
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
1811
+ f.write(json.dumps([gt.tolist(), pred.tolist()]))
1812
+ if world_size > 1:
1813
+ torch.distributed.barrier()
1814
+ if rank == 0:
1815
+ gt = []
1816
+ pred = []
1817
+ print(f"evaluate on rank {rank}. world size is {world_size}")
1818
+ for rank_i in range(world_size):
1819
+ [gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
1820
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
1821
+ gt.extend(gt_part)
1822
+ pred.extend(pred_part)
1823
+ print("total num:", len(gt))
1824
+ recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
1825
+ print("recalls:", recalls)
1826
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f:
1827
+ f.write(f"{gt}\n")
1828
+ f.write(f"{pred}\n")
1829
+ f.write(f"{recalls}\n")
1830
+ score = 0.0
1831
+ if world_size > 1:
1832
+ torch.distributed.barrier()
1833
+ return score
1834
+
1835
+
1836
+
1837
+ if __name__ == "__main__":
1838
+ main()
multimodal/build/lib/open_flamingo/eval/imagenet_utils.py ADDED
@@ -0,0 +1,1007 @@
1
+ # classnames via https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/imagenet_classnames.py#L1
2
+ openai_imagenet_classnames = [
3
+ "tench",
4
+ "goldfish",
5
+ "great white shark",
6
+ "tiger shark",
7
+ "hammerhead shark",
8
+ "electric ray",
9
+ "stingray",
10
+ "rooster",
11
+ "hen",
12
+ "ostrich",
13
+ "brambling",
14
+ "goldfinch",
15
+ "house finch",
16
+ "junco",
17
+ "indigo bunting",
18
+ "American robin",
19
+ "bulbul",
20
+ "jay",
21
+ "magpie",
22
+ "chickadee",
23
+ "American dipper",
24
+ "kite (bird of prey)",
25
+ "bald eagle",
26
+ "vulture",
27
+ "great grey owl",
28
+ "fire salamander",
29
+ "smooth newt",
30
+ "newt",
31
+ "spotted salamander",
32
+ "axolotl",
33
+ "American bullfrog",
34
+ "tree frog",
35
+ "tailed frog",
36
+ "loggerhead sea turtle",
37
+ "leatherback sea turtle",
38
+ "mud turtle",
39
+ "terrapin",
40
+ "box turtle",
41
+ "banded gecko",
42
+ "green iguana",
43
+ "Carolina anole",
44
+ "desert grassland whiptail lizard",
45
+ "agama",
46
+ "frilled-necked lizard",
47
+ "alligator lizard",
48
+ "Gila monster",
49
+ "European green lizard",
50
+ "chameleon",
51
+ "Komodo dragon",
52
+ "Nile crocodile",
53
+ "American alligator",
54
+ "triceratops",
55
+ "worm snake",
56
+ "ring-necked snake",
57
+ "eastern hog-nosed snake",
58
+ "smooth green snake",
59
+ "kingsnake",
60
+ "garter snake",
61
+ "water snake",
62
+ "vine snake",
63
+ "night snake",
64
+ "boa constrictor",
65
+ "African rock python",
66
+ "Indian cobra",
67
+ "green mamba",
68
+ "sea snake",
69
+ "Saharan horned viper",
70
+ "eastern diamondback rattlesnake",
71
+ "sidewinder rattlesnake",
72
+ "trilobite",
73
+ "harvestman",
74
+ "scorpion",
75
+ "yellow garden spider",
76
+ "barn spider",
77
+ "European garden spider",
78
+ "southern black widow",
79
+ "tarantula",
80
+ "wolf spider",
81
+ "tick",
82
+ "centipede",
83
+ "black grouse",
84
+ "ptarmigan",
85
+ "ruffed grouse",
86
+ "prairie grouse",
87
+ "peafowl",
88
+ "quail",
89
+ "partridge",
90
+ "african grey parrot",
91
+ "macaw",
92
+ "sulphur-crested cockatoo",
93
+ "lorikeet",
94
+ "coucal",
95
+ "bee eater",
96
+ "hornbill",
97
+ "hummingbird",
98
+ "jacamar",
99
+ "toucan",
100
+ "duck",
101
+ "red-breasted merganser",
102
+ "goose",
103
+ "black swan",
104
+ "tusker",
105
+ "echidna",
106
+ "platypus",
107
+ "wallaby",
108
+ "koala",
109
+ "wombat",
110
+ "jellyfish",
111
+ "sea anemone",
112
+ "brain coral",
113
+ "flatworm",
114
+ "nematode",
115
+ "conch",
116
+ "snail",
117
+ "slug",
118
+ "sea slug",
119
+ "chiton",
120
+ "chambered nautilus",
121
+ "Dungeness crab",
122
+ "rock crab",
123
+ "fiddler crab",
124
+ "red king crab",
125
+ "American lobster",
126
+ "spiny lobster",
127
+ "crayfish",
128
+ "hermit crab",
129
+ "isopod",
130
+ "white stork",
131
+ "black stork",
132
+ "spoonbill",
133
+ "flamingo",
134
+ "little blue heron",
135
+ "great egret",
136
+ "bittern bird",
137
+ "crane bird",
138
+ "limpkin",
139
+ "common gallinule",
140
+ "American coot",
141
+ "bustard",
142
+ "ruddy turnstone",
143
+ "dunlin",
144
+ "common redshank",
145
+ "dowitcher",
146
+ "oystercatcher",
147
+ "pelican",
148
+ "king penguin",
149
+ "albatross",
150
+ "grey whale",
151
+ "killer whale",
152
+ "dugong",
153
+ "sea lion",
154
+ "Chihuahua",
155
+ "Japanese Chin",
156
+ "Maltese",
157
+ "Pekingese",
158
+ "Shih Tzu",
159
+ "King Charles Spaniel",
160
+ "Papillon",
161
+ "toy terrier",
162
+ "Rhodesian Ridgeback",
163
+ "Afghan Hound",
164
+ "Basset Hound",
165
+ "Beagle",
166
+ "Bloodhound",
167
+ "Bluetick Coonhound",
168
+ "Black and Tan Coonhound",
169
+ "Treeing Walker Coonhound",
170
+ "English foxhound",
171
+ "Redbone Coonhound",
172
+ "borzoi",
173
+ "Irish Wolfhound",
174
+ "Italian Greyhound",
175
+ "Whippet",
176
+ "Ibizan Hound",
177
+ "Norwegian Elkhound",
178
+ "Otterhound",
179
+ "Saluki",
180
+ "Scottish Deerhound",
181
+ "Weimaraner",
182
+ "Staffordshire Bull Terrier",
183
+ "American Staffordshire Terrier",
184
+ "Bedlington Terrier",
185
+ "Border Terrier",
186
+ "Kerry Blue Terrier",
187
+ "Irish Terrier",
188
+ "Norfolk Terrier",
189
+ "Norwich Terrier",
190
+ "Yorkshire Terrier",
191
+ "Wire Fox Terrier",
192
+ "Lakeland Terrier",
193
+ "Sealyham Terrier",
194
+ "Airedale Terrier",
195
+ "Cairn Terrier",
196
+ "Australian Terrier",
197
+ "Dandie Dinmont Terrier",
198
+ "Boston Terrier",
199
+ "Miniature Schnauzer",
200
+ "Giant Schnauzer",
201
+ "Standard Schnauzer",
202
+ "Scottish Terrier",
203
+ "Tibetan Terrier",
204
+ "Australian Silky Terrier",
205
+ "Soft-coated Wheaten Terrier",
206
+ "West Highland White Terrier",
207
+ "Lhasa Apso",
208
+ "Flat-Coated Retriever",
209
+ "Curly-coated Retriever",
210
+ "Golden Retriever",
211
+ "Labrador Retriever",
212
+ "Chesapeake Bay Retriever",
213
+ "German Shorthaired Pointer",
214
+ "Vizsla",
215
+ "English Setter",
216
+ "Irish Setter",
217
+ "Gordon Setter",
218
+ "Brittany dog",
219
+ "Clumber Spaniel",
220
+ "English Springer Spaniel",
221
+ "Welsh Springer Spaniel",
222
+ "Cocker Spaniel",
223
+ "Sussex Spaniel",
224
+ "Irish Water Spaniel",
225
+ "Kuvasz",
226
+ "Schipperke",
227
+ "Groenendael dog",
228
+ "Malinois",
229
+ "Briard",
230
+ "Australian Kelpie",
231
+ "Komondor",
232
+ "Old English Sheepdog",
233
+ "Shetland Sheepdog",
234
+ "collie",
235
+ "Border Collie",
236
+ "Bouvier des Flandres dog",
237
+ "Rottweiler",
238
+ "German Shepherd Dog",
239
+ "Dobermann",
240
+ "Miniature Pinscher",
241
+ "Greater Swiss Mountain Dog",
242
+ "Bernese Mountain Dog",
243
+ "Appenzeller Sennenhund",
244
+ "Entlebucher Sennenhund",
245
+ "Boxer",
246
+ "Bullmastiff",
247
+ "Tibetan Mastiff",
248
+ "French Bulldog",
249
+ "Great Dane",
250
+ "St. Bernard",
251
+ "husky",
252
+ "Alaskan Malamute",
253
+ "Siberian Husky",
254
+ "Dalmatian",
255
+ "Affenpinscher",
256
+ "Basenji",
257
+ "pug",
258
+ "Leonberger",
259
+ "Newfoundland dog",
260
+ "Great Pyrenees dog",
261
+ "Samoyed",
262
+ "Pomeranian",
263
+ "Chow Chow",
264
+ "Keeshond",
265
+ "brussels griffon",
266
+ "Pembroke Welsh Corgi",
267
+ "Cardigan Welsh Corgi",
268
+ "Toy Poodle",
269
+ "Miniature Poodle",
270
+ "Standard Poodle",
271
+ "Mexican hairless dog (xoloitzcuintli)",
272
+ "grey wolf",
273
+ "Alaskan tundra wolf",
274
+ "red wolf or maned wolf",
275
+ "coyote",
276
+ "dingo",
277
+ "dhole",
278
+ "African wild dog",
279
+ "hyena",
280
+ "red fox",
281
+ "kit fox",
282
+ "Arctic fox",
283
+ "grey fox",
284
+ "tabby cat",
285
+ "tiger cat",
286
+ "Persian cat",
287
+ "Siamese cat",
288
+ "Egyptian Mau",
289
+ "cougar",
290
+ "lynx",
291
+ "leopard",
292
+ "snow leopard",
293
+ "jaguar",
294
+ "lion",
295
+ "tiger",
296
+ "cheetah",
297
+ "brown bear",
298
+ "American black bear",
299
+ "polar bear",
300
+ "sloth bear",
301
+ "mongoose",
302
+ "meerkat",
303
+ "tiger beetle",
304
+ "ladybug",
305
+ "ground beetle",
306
+ "longhorn beetle",
307
+ "leaf beetle",
308
+ "dung beetle",
309
+ "rhinoceros beetle",
310
+ "weevil",
311
+ "fly",
312
+ "bee",
313
+ "ant",
314
+ "grasshopper",
315
+ "cricket insect",
316
+ "stick insect",
317
+ "cockroach",
318
+ "praying mantis",
319
+ "cicada",
320
+ "leafhopper",
321
+ "lacewing",
322
+ "dragonfly",
323
+ "damselfly",
324
+ "red admiral butterfly",
325
+ "ringlet butterfly",
326
+ "monarch butterfly",
327
+ "small white butterfly",
328
+ "sulphur butterfly",
329
+ "gossamer-winged butterfly",
330
+ "starfish",
331
+ "sea urchin",
332
+ "sea cucumber",
333
+ "cottontail rabbit",
334
+ "hare",
335
+ "Angora rabbit",
336
+ "hamster",
337
+ "porcupine",
338
+ "fox squirrel",
339
+ "marmot",
340
+ "beaver",
341
+ "guinea pig",
342
+ "common sorrel horse",
343
+ "zebra",
344
+ "pig",
345
+ "wild boar",
346
+ "warthog",
347
+ "hippopotamus",
348
+ "ox",
349
+ "water buffalo",
350
+ "bison",
351
+ "ram (adult male sheep)",
352
+ "bighorn sheep",
353
+ "Alpine ibex",
354
+ "hartebeest",
355
+ "impala (antelope)",
356
+ "gazelle",
357
+ "arabian camel",
358
+ "llama",
359
+ "weasel",
360
+ "mink",
361
+ "European polecat",
362
+ "black-footed ferret",
363
+ "otter",
364
+ "skunk",
365
+ "badger",
366
+ "armadillo",
367
+ "three-toed sloth",
368
+ "orangutan",
369
+ "gorilla",
370
+ "chimpanzee",
371
+ "gibbon",
372
+ "siamang",
373
+ "guenon",
374
+ "patas monkey",
375
+ "baboon",
376
+ "macaque",
377
+ "langur",
378
+ "black-and-white colobus",
379
+ "proboscis monkey",
380
+ "marmoset",
381
+ "white-headed capuchin",
382
+ "howler monkey",
383
+ "titi monkey",
384
+ "Geoffroy's spider monkey",
385
+ "common squirrel monkey",
386
+ "ring-tailed lemur",
387
+ "indri",
388
+ "Asian elephant",
389
+ "African bush elephant",
390
+ "red panda",
391
+ "giant panda",
392
+ "snoek fish",
393
+ "eel",
394
+ "silver salmon",
395
+ "rock beauty fish",
396
+ "clownfish",
397
+ "sturgeon",
398
+ "gar fish",
399
+ "lionfish",
400
+ "pufferfish",
401
+ "abacus",
402
+ "abaya",
403
+ "academic gown",
404
+ "accordion",
405
+ "acoustic guitar",
406
+ "aircraft carrier",
407
+ "airliner",
408
+ "airship",
409
+ "altar",
410
+ "ambulance",
411
+ "amphibious vehicle",
412
+ "analog clock",
413
+ "apiary",
414
+ "apron",
415
+ "trash can",
416
+ "assault rifle",
417
+ "backpack",
418
+ "bakery",
419
+ "balance beam",
420
+ "balloon",
421
+ "ballpoint pen",
422
+ "Band-Aid",
423
+ "banjo",
424
+ "baluster / handrail",
425
+ "barbell",
426
+ "barber chair",
427
+ "barbershop",
428
+ "barn",
429
+ "barometer",
430
+ "barrel",
431
+ "wheelbarrow",
432
+ "baseball",
433
+ "basketball",
434
+ "bassinet",
435
+ "bassoon",
436
+ "swimming cap",
437
+ "bath towel",
438
+ "bathtub",
439
+ "station wagon",
440
+ "lighthouse",
441
+ "beaker",
442
+ "military hat (bearskin or shako)",
443
+ "beer bottle",
444
+ "beer glass",
445
+ "bell tower",
446
+ "baby bib",
447
+ "tandem bicycle",
448
+ "bikini",
449
+ "ring binder",
450
+ "binoculars",
451
+ "birdhouse",
452
+ "boathouse",
453
+ "bobsleigh",
454
+ "bolo tie",
455
+ "poke bonnet",
456
+ "bookcase",
457
+ "bookstore",
458
+ "bottle cap",
459
+ "hunting bow",
460
+ "bow tie",
461
+ "brass memorial plaque",
462
+ "bra",
463
+ "breakwater",
464
+ "breastplate",
465
+ "broom",
466
+ "bucket",
467
+ "buckle",
468
+ "bulletproof vest",
469
+ "high-speed train",
470
+ "butcher shop",
471
+ "taxicab",
472
+ "cauldron",
473
+ "candle",
474
+ "cannon",
475
+ "canoe",
476
+ "can opener",
477
+ "cardigan",
478
+ "car mirror",
479
+ "carousel",
480
+ "tool kit",
481
+ "cardboard box / carton",
482
+ "car wheel",
483
+ "automated teller machine",
484
+ "cassette",
485
+ "cassette player",
486
+ "castle",
487
+ "catamaran",
488
+ "CD player",
489
+ "cello",
490
+ "mobile phone",
491
+ "chain",
492
+ "chain-link fence",
493
+ "chain mail",
494
+ "chainsaw",
495
+ "storage chest",
496
+ "chiffonier",
497
+ "bell or wind chime",
498
+ "china cabinet",
499
+ "Christmas stocking",
500
+ "church",
501
+ "movie theater",
502
+ "cleaver",
503
+ "cliff dwelling",
504
+ "cloak",
505
+ "clogs",
506
+ "cocktail shaker",
507
+ "coffee mug",
508
+ "coffeemaker",
509
+ "spiral or coil",
510
+ "combination lock",
511
+ "computer keyboard",
512
+ "candy store",
513
+ "container ship",
514
+ "convertible",
515
+ "corkscrew",
516
+ "cornet",
517
+ "cowboy boot",
518
+ "cowboy hat",
519
+ "cradle",
520
+ "construction crane",
521
+ "crash helmet",
522
+ "crate",
523
+ "infant bed",
524
+ "Crock Pot",
525
+ "croquet ball",
526
+ "crutch",
527
+ "cuirass",
528
+ "dam",
529
+ "desk",
530
+ "desktop computer",
531
+ "rotary dial telephone",
532
+ "diaper",
533
+ "digital clock",
534
+ "digital watch",
535
+ "dining table",
536
+ "dishcloth",
537
+ "dishwasher",
538
+ "disc brake",
539
+ "dock",
540
+ "dog sled",
541
+ "dome",
542
+ "doormat",
543
+ "drilling rig",
544
+ "drum",
545
+ "drumstick",
546
+ "dumbbell",
547
+ "Dutch oven",
548
+ "electric fan",
549
+ "electric guitar",
550
+ "electric locomotive",
551
+ "entertainment center",
552
+ "envelope",
553
+ "espresso machine",
554
+ "face powder",
555
+ "feather boa",
556
+ "filing cabinet",
557
+ "fireboat",
558
+ "fire truck",
559
+ "fire screen",
560
+ "flagpole",
561
+ "flute",
562
+ "folding chair",
563
+ "football helmet",
564
+ "forklift",
565
+ "fountain",
566
+ "fountain pen",
567
+ "four-poster bed",
568
+ "freight car",
569
+ "French horn",
570
+ "frying pan",
571
+ "fur coat",
572
+ "garbage truck",
573
+ "gas mask or respirator",
574
+ "gas pump",
575
+ "goblet",
576
+ "go-kart",
577
+ "golf ball",
578
+ "golf cart",
579
+ "gondola",
580
+ "gong",
581
+ "gown",
582
+ "grand piano",
583
+ "greenhouse",
584
+ "radiator grille",
585
+ "grocery store",
586
+ "guillotine",
587
+ "hair clip",
588
+ "hair spray",
589
+ "half-track",
590
+ "hammer",
591
+ "hamper",
592
+ "hair dryer",
593
+ "hand-held computer",
594
+ "handkerchief",
595
+ "hard disk drive",
596
+ "harmonica",
597
+ "harp",
598
+ "combine harvester",
599
+ "hatchet",
600
+ "holster",
601
+ "home theater",
602
+ "honeycomb",
603
+ "hook",
604
+ "hoop skirt",
605
+ "gymnastic horizontal bar",
606
+ "horse-drawn vehicle",
607
+ "hourglass",
608
+ "iPod",
609
+ "clothes iron",
610
+ "carved pumpkin",
611
+ "jeans",
612
+ "jeep",
613
+ "T-shirt",
614
+ "jigsaw puzzle",
615
+ "rickshaw",
616
+ "joystick",
617
+ "kimono",
618
+ "knee pad",
619
+ "knot",
620
+ "lab coat",
621
+ "ladle",
622
+ "lampshade",
623
+ "laptop computer",
624
+ "lawn mower",
625
+ "lens cap",
626
+ "letter opener",
627
+ "library",
628
+ "lifeboat",
629
+ "lighter",
630
+ "limousine",
631
+ "ocean liner",
632
+ "lipstick",
633
+ "slip-on shoe",
634
+ "lotion",
635
+ "music speaker",
636
+ "loupe magnifying glass",
637
+ "sawmill",
638
+ "magnetic compass",
639
+ "messenger bag",
640
+ "mailbox",
641
+ "tights",
642
+ "one-piece bathing suit",
643
+ "manhole cover",
644
+ "maraca",
645
+ "marimba",
646
+ "mask",
647
+ "matchstick",
648
+ "maypole",
649
+ "maze",
650
+ "measuring cup",
651
+ "medicine cabinet",
652
+ "megalith",
653
+ "microphone",
654
+ "microwave oven",
655
+ "military uniform",
656
+ "milk can",
657
+ "minibus",
658
+ "miniskirt",
659
+ "minivan",
660
+ "missile",
661
+ "mitten",
662
+ "mixing bowl",
663
+ "mobile home",
664
+ "ford model t",
665
+ "modem",
666
+ "monastery",
667
+ "monitor",
668
+ "moped",
669
+ "mortar and pestle",
670
+ "graduation cap",
671
+ "mosque",
672
+ "mosquito net",
673
+ "vespa",
674
+ "mountain bike",
675
+ "tent",
676
+ "computer mouse",
677
+ "mousetrap",
678
+ "moving van",
679
+ "muzzle",
680
+ "metal nail",
681
+ "neck brace",
682
+ "necklace",
683
+ "baby pacifier",
684
+ "notebook computer",
685
+ "obelisk",
686
+ "oboe",
687
+ "ocarina",
688
+ "odometer",
689
+ "oil filter",
690
+ "pipe organ",
691
+ "oscilloscope",
692
+ "overskirt",
693
+ "bullock cart",
694
+ "oxygen mask",
695
+ "product packet / packaging",
696
+ "paddle",
697
+ "paddle wheel",
698
+ "padlock",
699
+ "paintbrush",
700
+ "pajamas",
701
+ "palace",
702
+ "pan flute",
703
+ "paper towel",
704
+ "parachute",
705
+ "parallel bars",
706
+ "park bench",
707
+ "parking meter",
708
+ "railroad car",
709
+ "patio",
710
+ "payphone",
711
+ "pedestal",
712
+ "pencil case",
713
+ "pencil sharpener",
714
+ "perfume",
715
+ "Petri dish",
716
+ "photocopier",
717
+ "plectrum",
718
+ "Pickelhaube",
719
+ "picket fence",
720
+ "pickup truck",
721
+ "pier",
722
+ "piggy bank",
723
+ "pill bottle",
724
+ "pillow",
725
+ "ping-pong ball",
726
+ "pinwheel",
727
+ "pirate ship",
728
+ "drink pitcher",
729
+ "block plane",
730
+ "planetarium",
731
+ "plastic bag",
732
+ "plate rack",
733
+ "farm plow",
734
+ "plunger",
735
+ "Polaroid camera",
736
+ "pole",
737
+ "police van",
738
+ "poncho",
739
+ "pool table",
740
+ "soda bottle",
741
+ "plant pot",
742
+ "potter's wheel",
743
+ "power drill",
744
+ "prayer rug",
745
+ "printer",
746
+ "prison",
747
+ "missile",
748
+ "projector",
749
+ "hockey puck",
750
+ "punching bag",
751
+ "purse",
752
+ "quill",
753
+ "quilt",
754
+ "race car",
755
+ "racket",
756
+ "radiator",
757
+ "radio",
758
+ "radio telescope",
759
+ "rain barrel",
760
+ "recreational vehicle",
761
+ "fishing casting reel",
762
+ "reflex camera",
763
+ "refrigerator",
764
+ "remote control",
765
+ "restaurant",
766
+ "revolver",
767
+ "rifle",
768
+ "rocking chair",
769
+ "rotisserie",
770
+ "eraser",
771
+ "rugby ball",
772
+ "ruler measuring stick",
773
+ "sneaker",
774
+ "safe",
775
+ "safety pin",
776
+ "salt shaker",
777
+ "sandal",
778
+ "sarong",
779
+ "saxophone",
780
+ "scabbard",
781
+ "weighing scale",
782
+ "school bus",
783
+ "schooner",
784
+ "scoreboard",
785
+ "CRT monitor",
786
+ "screw",
787
+ "screwdriver",
788
+ "seat belt",
789
+ "sewing machine",
790
+ "shield",
791
+ "shoe store",
792
+ "shoji screen / room divider",
793
+ "shopping basket",
794
+ "shopping cart",
795
+ "shovel",
796
+ "shower cap",
797
+ "shower curtain",
798
+ "ski",
799
+ "balaclava ski mask",
800
+ "sleeping bag",
801
+ "slide rule",
802
+ "sliding door",
803
+ "slot machine",
804
+ "snorkel",
805
+ "snowmobile",
806
+ "snowplow",
807
+ "soap dispenser",
808
+ "soccer ball",
809
+ "sock",
810
+ "solar thermal collector",
811
+ "sombrero",
812
+ "soup bowl",
813
+ "keyboard space bar",
814
+ "space heater",
815
+ "space shuttle",
816
+ "spatula",
817
+ "motorboat",
818
+ "spider web",
819
+ "spindle",
820
+ "sports car",
821
+ "spotlight",
822
+ "stage",
823
+ "steam locomotive",
824
+ "through arch bridge",
825
+ "steel drum",
826
+ "stethoscope",
827
+ "scarf",
828
+ "stone wall",
829
+ "stopwatch",
830
+ "stove",
831
+ "strainer",
832
+ "tram",
833
+ "stretcher",
834
+ "couch",
835
+ "stupa",
836
+ "submarine",
837
+ "suit",
838
+ "sundial",
839
+ "sunglasses",
840
+ "sunglasses",
841
+ "sunscreen",
842
+ "suspension bridge",
843
+ "mop",
844
+ "sweatshirt",
845
+ "swim trunks / shorts",
846
+ "swing",
847
+ "electrical switch",
848
+ "syringe",
849
+ "table lamp",
850
+ "tank",
851
+ "tape player",
852
+ "teapot",
853
+ "teddy bear",
854
+ "television",
855
+ "tennis ball",
856
+ "thatched roof",
857
+ "front curtain",
858
+ "thimble",
859
+ "threshing machine",
860
+ "throne",
861
+ "tile roof",
862
+ "toaster",
863
+ "tobacco shop",
864
+ "toilet seat",
865
+ "torch",
866
+ "totem pole",
867
+ "tow truck",
868
+ "toy store",
869
+ "tractor",
870
+ "semi-trailer truck",
871
+ "tray",
872
+ "trench coat",
873
+ "tricycle",
874
+ "trimaran",
875
+ "tripod",
876
+ "triumphal arch",
877
+ "trolleybus",
878
+ "trombone",
879
+ "hot tub",
880
+ "turnstile",
881
+ "typewriter keyboard",
882
+ "umbrella",
883
+ "unicycle",
884
+ "upright piano",
885
+ "vacuum cleaner",
886
+ "vase",
887
+ "vaulted or arched ceiling",
888
+ "velvet fabric",
889
+ "vending machine",
890
+ "vestment",
891
+ "viaduct",
892
+ "violin",
893
+ "volleyball",
894
+ "waffle iron",
895
+ "wall clock",
896
+ "wallet",
897
+ "wardrobe",
898
+ "military aircraft",
899
+ "sink",
900
+ "washing machine",
901
+ "water bottle",
902
+ "water jug",
903
+ "water tower",
904
+ "whiskey jug",
905
+ "whistle",
906
+ "hair wig",
907
+ "window screen",
908
+ "window shade",
909
+ "Windsor tie",
910
+ "wine bottle",
911
+ "airplane wing",
912
+ "wok",
913
+ "wooden spoon",
914
+ "wool",
915
+ "split-rail fence",
916
+ "shipwreck",
917
+ "sailboat",
918
+ "yurt",
919
+ "website",
920
+ "comic book",
921
+ "crossword",
922
+ "traffic or street sign",
923
+ "traffic light",
924
+ "dust jacket",
925
+ "menu",
926
+ "plate",
927
+ "guacamole",
928
+ "consomme",
929
+ "hot pot",
930
+ "trifle",
931
+ "ice cream",
932
+ "popsicle",
933
+ "baguette",
934
+ "bagel",
935
+ "pretzel",
936
+ "cheeseburger",
937
+ "hot dog",
938
+ "mashed potatoes",
939
+ "cabbage",
940
+ "broccoli",
941
+ "cauliflower",
942
+ "zucchini",
943
+ "spaghetti squash",
944
+ "acorn squash",
945
+ "butternut squash",
946
+ "cucumber",
947
+ "artichoke",
948
+ "bell pepper",
949
+ "cardoon",
950
+ "mushroom",
951
+ "Granny Smith apple",
952
+ "strawberry",
953
+ "orange",
954
+ "lemon",
955
+ "fig",
956
+ "pineapple",
957
+ "banana",
958
+ "jackfruit",
959
+ "cherimoya (custard apple)",
960
+ "pomegranate",
961
+ "hay",
962
+ "carbonara",
963
+ "chocolate syrup",
964
+ "dough",
965
+ "meatloaf",
966
+ "pizza",
967
+ "pot pie",
968
+ "burrito",
969
+ "red wine",
970
+ "espresso",
971
+ "tea cup",
972
+ "eggnog",
973
+ "mountain",
974
+ "bubble",
975
+ "cliff",
976
+ "coral reef",
977
+ "geyser",
978
+ "lakeshore",
979
+ "promontory",
980
+ "sandbar",
981
+ "beach",
982
+ "valley",
983
+ "volcano",
984
+ "baseball player",
985
+ "bridegroom",
986
+ "scuba diver",
987
+ "rapeseed",
988
+ "daisy",
989
+ "yellow lady's slipper",
990
+ "corn",
991
+ "acorn",
992
+ "rose hip",
993
+ "horse chestnut seed",
994
+ "coral fungus",
995
+ "agaric",
996
+ "gyromitra",
997
+ "stinkhorn mushroom",
998
+ "earth star fungus",
999
+ "hen of the woods mushroom",
1000
+ "bolete",
1001
+ "corn cob",
1002
+ "toilet paper",
1003
+ ]
1004
+ # Maps numeric class ids to labels
1005
+ IMAGENET_1K_CLASS_ID_TO_LABEL = dict(
1006
+ zip(range(len(openai_imagenet_classnames)), openai_imagenet_classnames)
1007
+ )
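A quick sanity-check sketch of the resulting mapping (indices follow the list order above):

assert len(openai_imagenet_classnames) == 1000
assert IMAGENET_1K_CLASS_ID_TO_LABEL[0] == "tench"
assert IMAGENET_1K_CLASS_ID_TO_LABEL[999] == "toilet paper"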
multimodal/build/lib/open_flamingo/eval/ok_vqa_utils.py ADDED
@@ -0,0 +1,213 @@
1
+ # These are manual mappings that are not caught by our stemming rules or would
+ # be handled incorrectly by our automatic stemming rules. In detail,
+ # the keys of the _MANUAL_MATCHES dict contain the original word and the value
+ # contains the transformation of the word expected by the OKVQA stemming rule.
+ # These manual rules were found by checking the `raw_answers` and the `answers`
+ # fields of the released OKVQA dataset and checking all things that were not
+ # properly mapped by our automatic rules. In particular, some of the mappings
+ # are constant, e.g. christmas -> christmas, which was incorrectly
+ # singularized by our inflection.singularize.
10
+ import re
11
+ import nltk
12
+ from nltk.corpus.reader import VERB
13
+ import inflection
14
+
15
+ _MANUAL_MATCHES = {
16
+ "police": "police",
17
+ "las": "las",
18
+ "vegas": "vegas",
19
+ "yes": "yes",
20
+ "jeans": "jean",
21
+ "hell's": "hell",
22
+ "domino's": "domino",
23
+ "morning": "morn",
24
+ "clothes": "cloth",
25
+ "are": "are",
26
+ "riding": "ride",
27
+ "leaves": "leaf",
28
+ "dangerous": "danger",
29
+ "clothing": "cloth",
30
+ "texting": "text",
31
+ "kiting": "kite",
32
+ "firefighters": "firefight",
33
+ "ties": "tie",
34
+ "married": "married",
35
+ "teething": "teeth",
36
+ "gloves": "glove",
37
+ "tennis": "tennis",
38
+ "dining": "dine",
39
+ "directions": "direct",
40
+ "waves": "wave",
41
+ "christmas": "christmas",
42
+ "drives": "drive",
43
+ "pudding": "pud",
44
+ "coding": "code",
45
+ "plating": "plate",
46
+ "quantas": "quanta",
47
+ "hornes": "horn",
48
+ "graves": "grave",
49
+ "mating": "mate",
50
+ "paned": "pane",
51
+ "alertness": "alert",
52
+ "sunbathing": "sunbath",
53
+ "tenning": "ten",
54
+ "wetness": "wet",
55
+ "urinating": "urine",
56
+ "sickness": "sick",
57
+ "braves": "brave",
58
+ "firefighting": "firefight",
59
+ "lenses": "lens",
60
+ "reflections": "reflect",
61
+ "backpackers": "backpack",
62
+ "eatting": "eat",
63
+ "designers": "design",
64
+ "curiousity": "curious",
65
+ "playfulness": "play",
66
+ "blindness": "blind",
67
+ "hawke": "hawk",
68
+ "tomatoe": "tomato",
69
+ "rodeoing": "rodeo",
70
+ "brightness": "bright",
71
+ "circuses": "circus",
72
+ "skateboarders": "skateboard",
73
+ "staring": "stare",
74
+ "electronics": "electron",
75
+ "electicity": "elect",
76
+ "mountainous": "mountain",
77
+ "socializing": "social",
78
+ "hamburgers": "hamburg",
79
+ "caves": "cave",
80
+ "transitions": "transit",
81
+ "wading": "wade",
82
+ "creame": "cream",
83
+ "toileting": "toilet",
84
+ "sautee": "saute",
85
+ "buildings": "build",
86
+ "belongings": "belong",
87
+ "stockings": "stock",
88
+ "walle": "wall",
89
+ "cumulis": "cumuli",
90
+ "travelers": "travel",
91
+ "conducter": "conduct",
92
+ "browsing": "brows",
93
+ "pooping": "poop",
94
+ "haircutting": "haircut",
95
+ "toppings": "top",
96
+ "hearding": "heard",
97
+ "sunblocker": "sunblock",
98
+ "bases": "base",
99
+ "markings": "mark",
100
+ "mopeds": "mope",
101
+ "kindergartener": "kindergarten",
102
+ "pies": "pie",
103
+ "scrapbooking": "scrapbook",
104
+ "couponing": "coupon",
105
+ "meetings": "meet",
106
+ "elevators": "elev",
107
+ "lowes": "low",
108
+ "men's": "men",
109
+ "childrens": "children",
110
+ "shelves": "shelve",
111
+ "paintings": "paint",
112
+ "raines": "rain",
113
+ "paring": "pare",
114
+ "expressions": "express",
115
+ "routes": "rout",
116
+ "pease": "peas",
117
+ "vastness": "vast",
118
+ "awning": "awn",
119
+ "boy's": "boy",
120
+ "drunkenness": "drunken",
121
+ "teasing": "teas",
122
+ "conferences": "confer",
123
+ "ripeness": "ripe",
124
+ "suspenders": "suspend",
125
+ "earnings": "earn",
126
+ "reporters": "report",
127
+ "kid's": "kid",
128
+ "containers": "contain",
129
+ "corgie": "corgi",
130
+ "porche": "porch",
131
+ "microwaves": "microwave",
132
+ "batter's": "batter",
133
+ "sadness": "sad",
134
+ "apartments": "apart",
135
+ "oxygenize": "oxygen",
136
+ "striping": "stripe",
137
+ "purring": "pure",
138
+ "professionals": "profession",
139
+ "piping": "pipe",
140
+ "farmer's": "farmer",
141
+ "potatoe": "potato",
142
+ "emirates": "emir",
143
+ "womens": "women",
144
+ "veteran's": "veteran",
145
+ "wilderness": "wilder",
146
+ "propellers": "propel",
147
+ "alpes": "alp",
148
+ "charioteering": "chariot",
149
+ "swining": "swine",
150
+ "illness": "ill",
151
+ "crepte": "crept",
152
+ "adhesives": "adhesive",
153
+ "regent's": "regent",
154
+ "decorations": "decor",
155
+ "rabbies": "rabbi",
156
+ "overseas": "oversea",
157
+ "travellers": "travel",
158
+ "casings": "case",
159
+ "smugness": "smug",
160
+ "doves": "dove",
161
+ "nationals": "nation",
162
+ "mustange": "mustang",
163
+ "ringe": "ring",
164
+ "gondoliere": "gondolier",
165
+ "vacationing": "vacate",
166
+ "reminders": "remind",
167
+ "baldness": "bald",
168
+ "settings": "set",
169
+ "glaced": "glace",
170
+ "coniferous": "conifer",
171
+ "revelations": "revel",
172
+ "personals": "person",
173
+ "daughter's": "daughter",
174
+ "badness": "bad",
175
+ "projections": "project",
176
+ "polarizing": "polar",
177
+ "vandalizers": "vandal",
178
+ "minerals": "miner",
179
+ "protesters": "protest",
180
+ "controllers": "control",
181
+ "weddings": "wed",
182
+ "sometimes": "sometime",
183
+ "earing": "ear",
184
+ }
185
+
186
+
187
+ class OKVQAStemmer:
188
+ """Stemmer to match OKVQA v1.1 procedure."""
189
+
190
+ def __init__(self):
191
+ self._wordnet_lemmatizer = nltk.stem.WordNetLemmatizer()
192
+
193
+ def stem(self, input_string):
194
+ """Apply stemming."""
195
+ word_and_pos = nltk.pos_tag(nltk.tokenize.word_tokenize(input_string))
196
+ stemmed_words = []
197
+ for w, p in word_and_pos:
198
+ if w in _MANUAL_MATCHES:
199
+ w = _MANUAL_MATCHES[w]
200
+ elif w.endswith("ing"):
201
+ w = self._wordnet_lemmatizer.lemmatize(w, VERB)
202
+ elif p.startswith("NNS") or p.startswith("NNPS"):
203
+ w = inflection.singularize(w)
204
+ stemmed_words.append(w)
205
+ return " ".join(stemmed_words)
206
+
207
+
208
+ stemmer = OKVQAStemmer()
209
+
210
+
211
+ def postprocess_ok_vqa_generation(prediction) -> str:
212
+ prediction_stem = stemmer.stem(prediction)
213
+ return prediction_stem
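A minimal usage sketch (it assumes the required NLTK data — punkt, the POS tagger and WordNet — is available locally; both words hit the manual-match table, so the output does not depend on the tagger):

postprocess_ok_vqa_generation("texting")  # -> "text"
postprocess_ok_vqa_generation("riding")   # -> "ride"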
multimodal/build/lib/open_flamingo/eval/task/__init__.py ADDED
File without changes
multimodal/build/lib/open_flamingo/eval/task/caption.py ADDED
@@ -0,0 +1,419 @@
1
+ from lavis.datasets.builders import load_dataset
2
+ import torch
3
+ import more_itertools
4
+ from tqdm import tqdm
5
+ from coco_metric import compute_cider, postprocess_captioning_generation
6
+ import json
7
+ import time
8
+ import os
9
+ from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
10
+ from PIL import Image
11
+
12
+ class VisualLogitsProcessor(LogitsProcessor):
13
+ def __init__(self, tokenizer):
14
+ super().__init__()
15
+ self.tokenizer = tokenizer
16
+ self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
17
+ self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
18
+ self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
19
+ self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
20
+ self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
21
+ self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
22
+ self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
23
+ self.topk = 2
24
+
25
+ def __call__(self, input_ids, scores):
26
+ # print("decoding===>", self.tokenizer.decode(scores.sort(descending=True).indices.tolist()[0][:self.topk]))
27
+ # import pdb; pdb.set_trace()
28
+ if self.object_token_id in scores.sort(descending=True).indices.tolist()[0][1:self.topk] and self.eos_token_id not in scores.sort(descending=True).indices.tolist()[0][:self.topk] and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum():
29
+ scores[0, self.object_token_id] = 1000
30
+ if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
31
+ if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
32
+ # print("generate a previsual token next")
33
+ scores[0, self.previsual_token_id] = 1000
34
+ elif input_ids[0, -1] == self.previsual_token_id or input_ids[0, -1] == self.visual_token_id:
35
+ # print("stop to run bbox generation for " + "previsual" if input_ids[0, -1] == self.previsual_token_id else "visual")
36
+ scores[0, self.eos_token_id] = 1000
37
+ elif input_ids[0, -1] == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
38
+ # print("generate a visual token next")
39
+ scores[0, self.visual_token_id] = 1000
40
+ return scores
41
+
42
+
43
+ def prepare_batch_images(batch, image_processor):
44
+ batch_images = None
45
+ for b in batch:
46
+ b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
47
+ if batch_images is None:
48
+ batch_images = b_image
49
+ else:
50
+ batch_images = torch.cat([batch_images, b_image], dim=0)
51
+ return batch_images
52
+
53
+
54
+ def captioner(
55
+ model,tokenizer,image_ori,batch_images,input_ids,attention_mask,image_start_index_list,image_nums,added_bbox_list,debug=False):
56
+ """Generate a grounded caption for a single pre-processed image.
57
+ Returns:
58
+ tuple: (caption string, PIL image with predicted boxes drawn, or None)
59
+
60
+ """
61
+ visual_logits_processor = VisualLogitsProcessor(tokenizer)
62
+ model.eval()
63
+ # model.eval().cuda()
64
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
65
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
66
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
67
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
68
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
69
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
70
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
71
+ box_token = "<|#box#|>"
72
+ prebox_token = "<|#prebox#|>"
73
+ endofobject_token = "<|#endofobject#|>"
74
+ object_token = "<|#object#|>"
75
+ ori_prompt_length = len(input_ids[0])
76
+ have_prebox = False
77
+ out_image = None
78
+ while True:
79
+ batch_images = batch_images
80
+ input_ids = input_ids
81
+ attention_mask = attention_mask
82
+ image_start_index_list = image_start_index_list
83
+ image_nums = image_nums
84
+ if debug:
85
+ print("input--->",tokenizer.decode(input_ids[0]))
86
+ p1 = MinNewTokensLengthLogitsProcessor(
87
+ prompt_length_to_skip=input_ids.shape[-1],
88
+ min_new_tokens=5,
89
+ eos_token_id=bos_token_id,
90
+ )
91
+ with torch.inference_mode():
92
+ outputs = model.generate(
93
+ batch_images,
94
+ input_ids,
95
+ attention_mask=attention_mask,
96
+ max_new_tokens=20,
97
+ # min_new_tokens=8,
98
+ num_beams=1,
99
+ # length_penalty=0,
100
+ image_start_index_list=image_start_index_list,
101
+ image_nums=image_nums,
102
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
103
+ logits_processor_list=[p1, visual_logits_processor],
104
+ )
105
+ if debug:
106
+ print("outputs--->",tokenizer.decode(outputs[0]))
107
+ if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
108
+ prompt = tokenizer.decode(outputs.clone()[0])
109
+ is_visual = (outputs[0, -2] == visual_token_id)
110
+ batch_text = tokenizer.batch_decode(outputs[:, :-1])
111
+ encodings = tokenizer(
112
+ batch_text,
113
+ padding="longest",
114
+ truncation=True,
115
+ return_tensors="pt",
116
+ max_length=2000,
117
+ )
118
+ input_ids = encodings["input_ids"]
119
+ attention_mask = encodings["attention_mask"]
120
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
121
+ image_start_index_list = [[x] for x in image_start_index_list]
122
+ image_nums = [1] * len(input_ids)
123
+ if debug:
124
+ print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
125
+ with torch.no_grad():
126
+ outputs = model(
127
+ vision_x=batch_images,
128
+ lang_x=input_ids,
129
+ attention_mask=attention_mask,
130
+ image_nums=image_nums,
131
+ image_start_index_list=image_start_index_list,
132
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
133
+ add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
134
+ )
135
+ boxes = outputs["boxes"]
136
+ scores = outputs["scores"]
137
+ # if not model.valid:
138
+ # import pdb; pdb.set_trace()
139
+ if boxes is not None:
140
+ if is_visual:
141
+ if have_prebox:
142
+ added_bbox_list.pop()
143
+ prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
144
+ have_prebox = False
145
+ if debug:
146
+ print("find previsual and remove it--->", prompt)
147
+ first_box = boxes[scores.argmax()]
148
+ added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
149
+ prompt = prompt[:-len(tokenizer.eos_token)]
150
+ prompt += box_token + endofobject_token
151
+ if debug:
152
+ print("after inserting visual---->", prompt)
153
+ else:
154
+ import numpy as np
155
+ import cv2
156
+ open_cv_image = np.array(image_ori)
157
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
158
+ for i, pre_box in enumerate(boxes):
159
+ open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i+1)
160
+ out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
161
+ # exit()
162
+ pre_box = boxes[scores.argmax()]
163
+ added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
164
+ prompt = prompt[:-len(tokenizer.eos_token)]
165
+ prompt += prebox_token + object_token
166
+ have_prebox = True
167
+ if debug:
168
+ print("after inserting previsual---->", prompt)
169
+ else:
170
+ if debug:
171
+ import pdb;pdb.set_trace()
172
+ prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
173
+ else:
174
+ break
175
+ outputs = outputs[:, ori_prompt_length:]
176
+ outputs = postprocess_captioning_generation(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]).replace('"', "")
177
+ # new_predictions = [
178
+ # postprocess_captioning_generation(out).replace('"', "")
179
+ # for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
180
+ # ]
181
+ # import pdb; pdb.set_trace()
182
+ return outputs, out_image
183
+
184
+
185
+ def evaluate_coco_flickr(
186
+ model,
187
+ tokenizer,
188
+ image_processor,
189
+ batch_size,
190
+ is_flickr=False,
191
+ vis_embed_size=None,
192
+ rank=0,
193
+ world_size=1,
194
+ id=0,
195
+ debug=False,
196
+ ):
197
+ """Evaluate a model on COCO dataset.
198
+ Returns:
199
+ float: CIDEr score
200
+
201
+ """
202
+ visual_logits_processor = VisualLogitsProcessor(tokenizer)
203
+ coco_dataset = load_dataset("coco_caption")
204
+ eval_dataset = coco_dataset["test"]
205
+ model.eval().cuda()
206
+ predictions = dict()
207
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
208
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
209
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
210
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
211
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
212
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
213
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
214
+ box_token = "<|#box#|>"
215
+ prebox_token = "<|#prebox#|>"
216
+ endofobject_token = "<|#endofobject#|>"
217
+ object_token = "<|#object#|>"
218
+ cnt = 0
219
+ if world_size > 1:
220
+ torch.distributed.barrier()
221
+ desc = "Running inference Flickr30k" if is_flickr else "Running inference COCO"
222
+ for ii, batch in enumerate(more_itertools.chunked(
223
+ tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
224
+ )):
225
+ if ii % world_size != rank:
226
+ continue
227
+ cnt += len(batch)
228
+ batch[0]["image"] = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/images/img3.jpg").resize((224, 224))
229
+ batch_images = prepare_batch_images(
230
+ batch=batch,
231
+ image_processor=image_processor,
232
+ ).cuda()
233
+ prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
234
+ added_bbox_list = []
235
+ batch_text = [prompt for _ in batch]
236
+ encodings = tokenizer(
237
+ batch_text,
238
+ padding="longest",
239
+ truncation=True,
240
+ return_tensors="pt",
241
+ max_length=2000,
242
+ )
243
+ ori_prompt_length = len(encodings["input_ids"][0])
244
+ have_prebox = False
245
+ while True:
246
+ batch_text = [prompt for _ in batch]
247
+ encodings = tokenizer(
248
+ batch_text,
249
+ padding="longest",
250
+ truncation=True,
251
+ return_tensors="pt",
252
+ max_length=2000,
253
+ )
254
+ input_ids = encodings["input_ids"].cuda()
255
+ attention_mask = encodings["attention_mask"].cuda()
256
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
257
+ image_start_index_list = [[x] for x in image_start_index_list]
258
+ image_nums = [1] * len(input_ids)
259
+ if debug:
260
+ print("input--->",tokenizer.decode(input_ids[0]))
261
+ p1 = MinNewTokensLengthLogitsProcessor(
262
+ prompt_length_to_skip=input_ids.shape[-1],
263
+ min_new_tokens=5,
264
+ eos_token_id=bos_token_id,
265
+ )
266
+ with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
267
+ outputs = model.generate(
268
+ batch_images,
269
+ input_ids,
270
+ attention_mask=attention_mask,
271
+ max_new_tokens=20,
272
+ # min_new_tokens=8,
273
+ num_beams=1,
274
+ # length_penalty=0,
275
+ image_start_index_list=image_start_index_list,
276
+ image_nums=image_nums,
277
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
278
+ logits_processor_list=[p1, visual_logits_processor],
279
+ )
280
+ if debug:
281
+ print("outputs--->",tokenizer.decode(outputs[0]))
282
+ if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
283
+ prompt = tokenizer.decode(outputs.clone()[0])
284
+ is_visual = (outputs[0, -2] == visual_token_id)
285
+ batch_text = tokenizer.batch_decode(outputs[:, :-1])
286
+ encodings = tokenizer(
287
+ batch_text,
288
+ padding="longest",
289
+ truncation=True,
290
+ return_tensors="pt",
291
+ max_length=2000,
292
+ )
293
+ input_ids = encodings["input_ids"].cuda()
294
+ attention_mask = encodings["attention_mask"].cuda()
295
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
296
+ image_start_index_list = [[x] for x in image_start_index_list]
297
+ image_nums = [1] * len(input_ids)
298
+ if debug:
299
+ print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
300
+ with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
301
+ outputs = model(
302
+ vision_x=batch_images,
303
+ lang_x=input_ids,
304
+ attention_mask=attention_mask,
305
+ image_nums=image_nums,
306
+ image_start_index_list=image_start_index_list,
307
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
308
+ add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
309
+ )
310
+ boxes = outputs["boxes"]
311
+ scores = outputs["scores"]
312
+ # if not model.valid:
313
+ # import pdb; pdb.set_trace()
314
+ if boxes is not None:
315
+ if is_visual:
316
+ if have_prebox:
317
+ added_bbox_list.pop()
318
+ prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
319
+ have_prebox = False
320
+ if debug:
321
+ print("find previsual and remove it--->", prompt)
322
+ first_box = boxes[scores.argmax()]
323
+ added_bbox_list += [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
324
+ prompt = prompt[:-len(tokenizer.eos_token)]
325
+ prompt += box_token + endofobject_token
326
+ if debug:
327
+ print("after inserting visual---->", prompt)
328
+ else:
329
+ import numpy as np
330
+ import cv2
331
+ open_cv_image = np.array(batch[0]["image"])
332
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
333
+ for i, pre_box in enumerate(boxes):
334
+ open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i+1)
335
+ cv2.imwrite("Atest.png", open_cv_image)
336
+ exit()
337
+ pre_box = boxes[scores.argmax()]
338
+ added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
339
+ prompt = prompt[:-len(tokenizer.eos_token)]
340
+ prompt += prebox_token + object_token
341
+ have_prebox = True
342
+ if debug:
343
+ print("after inserting previsual---->", prompt)
344
+ else:
345
+ import pdb;pdb.set_trace()
346
+ prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
347
+ else:
348
+ break
349
+ outputs = outputs[:, ori_prompt_length:]
350
+ new_predictions = [
351
+ postprocess_captioning_generation(out).replace('"', "")
352
+ for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
353
+ ]
354
+ # import pdb; pdb.set_trace()
355
+ if rank == 0:
356
+ tqdm.write(new_predictions[0])
357
+ for i, sample in enumerate(batch):
358
+ predictions[int(sample["image_id"])] = {
359
+ "caption": new_predictions[i],
360
+ }
361
+ print(new_predictions)
362
+ exit()
363
+ results_path = (
364
+ f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
365
+ if is_flickr
366
+ else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
367
+ )
368
+ with open(results_path, "w") as f:
369
+ f.write(
370
+ json.dumps(
371
+ [
372
+ {"image_id": k, "caption": predictions[k]["caption"]}
373
+ for k in predictions
374
+ ],
375
+ indent=2,
376
+ )
377
+ )
378
+ print("save to", results_path)
379
+ del predictions
380
+ time.sleep(10)
381
+ if world_size > 1:
382
+ torch.distributed.barrier()
383
+ if rank == 0:
384
+ print(f"evaluate on rank {rank}. world size is {world_size}")
385
+ predictions = []
386
+ for rank_i in range(world_size):
387
+ part_results_path = (
388
+ f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
389
+ if is_flickr
390
+ else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
391
+ )
392
+ print("load", part_results_path)
393
+ predictions.extend(json.load(open(part_results_path)))
394
+ os.remove(part_results_path)
395
+ print("num:", len(predictions))
396
+ results_path = (
397
+ f"flickrresults_{lang_encoder_name}.json"
398
+ if is_flickr
399
+ else f"cocoresults_{lang_encoder_name}.json"
400
+ )
401
+ json.dump(predictions, open(results_path, "w"), indent=2)
402
+
403
+ metrics = compute_cider(
404
+ result_path=results_path,
405
+ annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
406
+ )
407
+ metrics["CIDEr"] *= 100
408
+ os.makedirs("eval_results", exist_ok=True)
409
+ acc = metrics["CIDEr"]
410
+ with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
411
+ f.write(json.dumps(predictions, indent=2))
412
+
413
+ # delete the temporary file
414
+ os.remove(results_path)
415
+ else:
416
+ metrics = {}
417
+ metrics["CIDEr"] = 0.0
418
+
419
+ return metrics["CIDEr"]
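A note on the decoding control used above: VisualLogitsProcessor steers greedy decoding by overwriting the logit of a chosen special token with a large constant (1000) whenever a trigger condition holds, so that token wins the argmax. A self-contained toy of that pattern, with hypothetical token ids and no model involved:

import torch
from transformers import LogitsProcessor

class ForceAfterTrigger(LogitsProcessor):
    """Toy processor: once TRIGGER was just generated, force FORCED on the next step."""
    TRIGGER = 7   # hypothetical id standing in for e.g. <|#object#|>
    FORCED = 11   # hypothetical id standing in for e.g. <|#visual#|>

    def __call__(self, input_ids, scores):
        if input_ids[0, -1] == self.TRIGGER:
            scores[0, self.FORCED] = 1000.0  # same trick as above: a huge logit wins greedy argmax
        return scores

processor = ForceAfterTrigger()
fake_ids = torch.tensor([[3, 5, 7]])   # last generated token is the trigger
fake_scores = torch.zeros(1, 32)       # toy vocabulary of 32 tokens
assert processor(fake_ids, fake_scores).argmax(dim=-1).item() == 11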
multimodal/build/lib/open_flamingo/eval/task/caption_chat.py ADDED
@@ -0,0 +1,417 @@
1
+ from lavis.datasets.builders import load_dataset
+ from coco_metric import compute_cider, postprocess_captioning_generation
2
+ import torch
3
+ import more_itertools
4
+ from tqdm import tqdm
5
+ import json
6
+ import time
7
+ import os
8
+ from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
9
+ from PIL import Image
10
+
11
+ class VisualLogitsProcessor(LogitsProcessor):
12
+ def __init__(self, tokenizer):
13
+ super().__init__()
14
+ self.tokenizer = tokenizer
15
+ self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
16
+ self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
17
+ self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
18
+ self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
19
+ self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
20
+ self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
21
+ self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
22
+ self.topk = 2
23
+
24
+ def __call__(self, input_ids, scores):
25
+ # print("decoding===>", self.tokenizer.decode(scores.sort(descending=True).indices.tolist()[0][:self.topk]))
26
+ # import pdb; pdb.set_trace()
27
+ if self.object_token_id in scores.sort(descending=True).indices.tolist()[0][1:self.topk] and self.eos_token_id not in scores.sort(descending=True).indices.tolist()[0][:self.topk] and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum():
28
+ scores[0, self.object_token_id] = 1000
29
+ if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
30
+ if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
31
+ # print("generate a previsual token next")
32
+ scores[0, self.previsual_token_id] = 1000
33
+ elif input_ids[0, -1] == self.previsual_token_id or input_ids[0, -1] == self.visual_token_id:
34
+ # print("stop to run bbox generation for " + "previsual" if input_ids[0, -1] == self.previsual_token_id else "visual")
35
+ scores[0, self.eos_token_id] = 1000
36
+ elif input_ids[0, -1] == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
37
+ # print("generate a visual token next")
38
+ scores[0, self.visual_token_id] = 1000
39
+ return scores
40
+
41
+
42
+ def prepare_batch_images(batch, image_processor):
43
+ batch_images = None
44
+ for b in batch:
45
+ b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
46
+ if batch_images is None:
47
+ batch_images = b_image
48
+ else:
49
+ batch_images = torch.cat([batch_images, b_image], dim=0)
50
+ return batch_images
51
+
52
+
53
+ def captioner(
54
+ model,tokenizer,image_ori,batch_images,input_ids,attention_mask,image_start_index_list,image_nums,added_bbox_list,debug=False):
55
+ """Generate a grounded caption for a single pre-processed image.
56
+ Returns:
57
+ tuple: (caption string, PIL image with predicted boxes drawn, or None)
58
+
59
+ """
60
+ visual_logits_processor = VisualLogitsProcessor(tokenizer)
61
+ model.eval()
62
+ # model.eval().cuda()
63
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
64
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
65
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
66
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
67
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
68
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
69
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
70
+ box_token = "<|#box#|>"
71
+ prebox_token = "<|#prebox#|>"
72
+ endofobject_token = "<|#endofobject#|>"
73
+ object_token = "<|#object#|>"
74
+ ori_prompt_length = len(input_ids[0])
75
+ have_prebox = False
+ out_image = None
76
+ while True:
77
+ batch_images = batch_images
78
+ input_ids = input_ids
79
+ attention_mask = attention_mask
80
+ image_start_index_list = image_start_index_list
81
+ image_nums = image_nums
82
+ if debug:
83
+ print("input--->",tokenizer.decode(input_ids[0]))
84
+ p1 = MinNewTokensLengthLogitsProcessor(
85
+ prompt_length_to_skip=input_ids.shape[-1],
86
+ min_new_tokens=5,
87
+ eos_token_id=bos_token_id,
88
+ )
89
+ with torch.inference_mode():
90
+ outputs = model.generate(
91
+ batch_images,
92
+ input_ids,
93
+ attention_mask=attention_mask,
94
+ max_new_tokens=20,
95
+ # min_new_tokens=8,
96
+ num_beams=1,
97
+ # length_penalty=0,
98
+ image_start_index_list=image_start_index_list,
99
+ image_nums=image_nums,
100
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
101
+ logits_processor_list=[p1, visual_logits_processor],
102
+ )
103
+ if debug:
104
+ print("outputs--->",tokenizer.decode(outputs[0]))
105
+ if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
106
+ prompt = tokenizer.decode(outputs.clone()[0])
107
+ is_visual = (outputs[0, -2] == visual_token_id)
108
+ batch_text = tokenizer.batch_decode(outputs[:, :-1])
109
+ encodings = tokenizer(
110
+ batch_text,
111
+ padding="longest",
112
+ truncation=True,
113
+ return_tensors="pt",
114
+ max_length=2000,
115
+ )
116
+ input_ids = encodings["input_ids"]
117
+ attention_mask = encodings["attention_mask"]
118
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
119
+ image_start_index_list = [[x] for x in image_start_index_list]
120
+ image_nums = [1] * len(input_ids)
121
+ if debug:
122
+ print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
123
+ with torch.no_grad():
124
+ outputs = model(
125
+ vision_x=batch_images,
126
+ lang_x=input_ids,
127
+ attention_mask=attention_mask,
128
+ image_nums=image_nums,
129
+ image_start_index_list=image_start_index_list,
130
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
131
+ add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
132
+ )
133
+ boxes = outputs["boxes"]
134
+ scores = outputs["scores"]
135
+ # if not model.valid:
136
+ # import pdb; pdb.set_trace()
137
+ if boxes is not None:
138
+ if is_visual:
139
+ if have_prebox:
140
+ added_bbox_list.pop()
141
+ prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
142
+ have_prebox = False
143
+ if debug:
144
+ print("find previsual and remove it--->", prompt)
145
+ first_box = boxes[scores.argmax()]
146
+ added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
147
+ prompt = prompt[:-len(tokenizer.eos_token)]
148
+ prompt += box_token + endofobject_token
149
+ if debug:
150
+ print("after inserting visual---->", prompt)
151
+ else:
152
+ import numpy as np
153
+ import cv2
154
+ open_cv_image = np.array(image_ori)
155
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
156
+ for i, pre_box in enumerate(boxes):
157
+ open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i+1)
158
+ out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
159
+ # exit()
160
+ pre_box = boxes[scores.argmax()]
161
+ added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
162
+ prompt = prompt[:-len(tokenizer.eos_token)]
163
+ prompt += prebox_token + object_token
164
+ have_prebox = True
165
+ if debug:
166
+ print("after inserting previsual---->", prompt)
167
+ else:
168
+ if debug:
169
+ import pdb;pdb.set_trace()
170
+ prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
171
+ else:
172
+ break
173
+ outputs = outputs[:, ori_prompt_length:]
174
+ outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].replace('"', "")
175
+ # new_predictions = [
176
+ # postprocess_captioning_generation(out).replace('"', "")
177
+ # for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
178
+ # ]
179
+ # import pdb; pdb.set_trace()
180
+ return outputs, out_image
181
+
182
+
183
+ def evaluate_coco_flickr(
184
+ model,
185
+ tokenizer,
186
+ image_processor,
187
+ batch_size,
188
+ is_flickr=False,
189
+ vis_embed_size=None,
190
+ rank=0,
191
+ world_size=1,
192
+ id=0,
193
+ debug=False,
194
+ ):
195
+ """Evaluate a model on COCO dataset.
196
+ Returns:
197
+ float: CIDEr score
198
+
199
+ """
200
+ visual_logits_processor = VisualLogitsProcessor(tokenizer)
201
+ coco_dataset = load_dataset("coco_caption")
202
+ eval_dataset = coco_dataset["test"]
203
+ model.eval().cuda()
204
+ predictions = dict()
205
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
206
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
207
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
208
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
209
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
210
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
211
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
212
+ box_token = "<|#box#|>"
213
+ prebox_token = "<|#prebox#|>"
214
+ endofobject_token = "<|#endofobject#|>"
215
+ object_token = "<|#object#|>"
216
+ cnt = 0
217
+ if world_size > 1:
218
+ torch.distributed.barrier()
219
+ desc = "Running inference Flickr30k" if is_flickr else "Running inference COCO"
220
+ for ii, batch in enumerate(more_itertools.chunked(
221
+ tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
222
+ )):
223
+ if ii % world_size != rank:
224
+ continue
225
+ cnt += len(batch)
226
+ batch[0]["image"] = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/images/img3.jpg").resize((224, 224))
227
+ batch_images = prepare_batch_images(
228
+ batch=batch,
229
+ image_processor=image_processor,
230
+ ).cuda()
231
+ prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
232
+ added_bbox_list = []
233
+ batch_text = [prompt for _ in batch]
234
+ encodings = tokenizer(
235
+ batch_text,
236
+ padding="longest",
237
+ truncation=True,
238
+ return_tensors="pt",
239
+ max_length=2000,
240
+ )
241
+ ori_prompt_length = len(encodings["input_ids"][0])
242
+ have_prebox = False
243
+ while True:
244
+ batch_text = [prompt for _ in batch]
245
+ encodings = tokenizer(
246
+ batch_text,
247
+ padding="longest",
248
+ truncation=True,
249
+ return_tensors="pt",
250
+ max_length=2000,
251
+ )
252
+ input_ids = encodings["input_ids"].cuda()
253
+ attention_mask = encodings["attention_mask"].cuda()
254
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
255
+ image_start_index_list = [[x] for x in image_start_index_list]
256
+ image_nums = [1] * len(input_ids)
257
+ if debug:
258
+ print("input--->",tokenizer.decode(input_ids[0]))
259
+ p1 = MinNewTokensLengthLogitsProcessor(
260
+ prompt_length_to_skip=input_ids.shape[-1],
261
+ min_new_tokens=5,
262
+ eos_token_id=bos_token_id,
263
+ )
264
+ with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
265
+ outputs = model.generate(
266
+ batch_images,
267
+ input_ids,
268
+ attention_mask=attention_mask,
269
+ max_new_tokens=20,
270
+ # min_new_tokens=8,
271
+ num_beams=1,
272
+ # length_penalty=0,
273
+ image_start_index_list=image_start_index_list,
274
+ image_nums=image_nums,
275
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
276
+ logits_processor_list=[p1, visual_logits_processor],
277
+ )
278
+ if debug:
279
+ print("outputs--->",tokenizer.decode(outputs[0]))
280
+ if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
281
+ prompt = tokenizer.decode(outputs.clone()[0])
282
+ is_visual = (outputs[0, -2] == visual_token_id)
283
+ batch_text = tokenizer.batch_decode(outputs[:, :-1])
284
+ encodings = tokenizer(
285
+ batch_text,
286
+ padding="longest",
287
+ truncation=True,
288
+ return_tensors="pt",
289
+ max_length=2000,
290
+ )
291
+ input_ids = encodings["input_ids"].cuda()
292
+ attention_mask = encodings["attention_mask"].cuda()
293
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
294
+ image_start_index_list = [[x] for x in image_start_index_list]
295
+ image_nums = [1] * len(input_ids)
296
+ if debug:
297
+ print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
298
+ with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
299
+ outputs = model(
300
+ vision_x=batch_images,
301
+ lang_x=input_ids,
302
+ attention_mask=attention_mask,
303
+ image_nums=image_nums,
304
+ image_start_index_list=image_start_index_list,
305
+ added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
306
+ add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
307
+ )
308
+ boxes = outputs["boxes"]
309
+ scores = outputs["scores"]
310
+ # if not model.valid:
311
+ # import pdb; pdb.set_trace()
312
+ if boxes is not None:
313
+ if is_visual:
314
+ if have_prebox:
315
+ added_bbox_list.pop()
316
+ prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
317
+ have_prebox = False
318
+ if debug:
319
+ print("find previsual and remove it--->", prompt)
320
+ first_box = boxes[scores.argmax()]
321
+ added_bbox_list += [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
322
+ prompt = prompt[:-len(tokenizer.eos_token)]
323
+ prompt += box_token + endofobject_token
324
+ if debug:
325
+ print("after inserting visual---->", prompt)
326
+ else:
327
+ import numpy as np
328
+ import cv2
329
+ open_cv_image = np.array(batch[0]["image"])
330
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
331
+ for i, pre_box in enumerate(boxes):
332
+ open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i+1)
333
+ cv2.imwrite("Atest.png", open_cv_image)
334
+ exit()
335
+ pre_box = boxes[scores.argmax()]
336
+ added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
337
+ prompt = prompt[:-len(tokenizer.eos_token)]
338
+ prompt += prebox_token + object_token
339
+ have_prebox = True
340
+ if debug:
341
+ print("after inserting previsual---->", prompt)
342
+ else:
343
+ import pdb;pdb.set_trace()
344
+ prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
345
+ else:
346
+ break
347
+ outputs = outputs[:, ori_prompt_length:]
348
+ new_predictions = [
349
+ postprocess_captioning_generation(out).replace('"', "")
350
+ for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
351
+ ]
352
+ # import pdb; pdb.set_trace()
353
+ if rank == 0:
354
+ tqdm.write(new_predictions[0])
355
+ for i, sample in enumerate(batch):
356
+ predictions[int(sample["image_id"])] = {
357
+ "caption": new_predictions[i],
358
+ }
359
+ print(new_predictions)
360
+ exit()
361
+ results_path = (
362
+ f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
363
+ if is_flickr
364
+ else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
365
+ )
366
+ with open(results_path, "w") as f:
367
+ f.write(
368
+ json.dumps(
369
+ [
370
+ {"image_id": k, "caption": predictions[k]["caption"]}
371
+ for k in predictions
372
+ ],
373
+ indent=2,
374
+ )
375
+ )
376
+ print("save to", results_path)
377
+ del predictions
378
+ time.sleep(10)
379
+ if world_size > 1:
380
+ torch.distributed.barrier()
381
+ if rank == 0:
382
+ print(f"evaluate on rank {rank}. world size is {world_size}")
383
+ predictions = []
384
+ for rank_i in range(world_size):
385
+ part_results_path = (
386
+ f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
387
+ if is_flickr
388
+ else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
389
+ )
390
+ print("load", part_results_path)
391
+ predictions.extend(json.load(open(part_results_path)))
392
+ os.remove(part_results_path)
393
+ print("num:", len(predictions))
394
+ results_path = (
395
+ f"flickrresults_{lang_encoder_name}.json"
396
+ if is_flickr
397
+ else f"cocoresults_{lang_encoder_name}.json"
398
+ )
399
+ json.dump(predictions, open(results_path, "w"), indent=2)
400
+
401
+ metrics = compute_cider(
402
+ result_path=results_path,
403
+ annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
404
+ )
405
+ metrics["CIDEr"] *= 100
406
+ os.makedirs("eval_results", exist_ok=True)
407
+ acc = metrics["CIDEr"]
408
+ with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
409
+ f.write(json.dumps(predictions, indent=2))
410
+
411
+ # delete the temporary file
412
+ os.remove(results_path)
413
+ else:
414
+ metrics = {}
415
+ metrics["CIDEr"] = 0.0
416
+
417
+ return metrics["CIDEr"]
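Both copies of prepare_batch_images stack each processed image into a 6-D tensor before handing it to the model. A quick shape check with a stand-in processor (the fake below is only a placeholder for the real preprocessing transform; reading the axes as (batch, T_img, frames, C, H, W) is the usual Flamingo-style layout rather than something stated in this file):

import torch

def fake_image_processor(_img):
    # placeholder for the real preprocessing transform; returns a (C, H, W) tensor
    return torch.zeros(3, 224, 224)

batch = [{"image": None}, {"image": None}]
batch_images = None
for b in batch:
    b_image = fake_image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    batch_images = b_image if batch_images is None else torch.cat([batch_images, b_image], dim=0)

print(batch_images.shape)  # torch.Size([2, 1, 1, 3, 224, 224])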
multimodal/build/lib/open_flamingo/eval/task/cola.py ADDED
@@ -0,0 +1,220 @@
1
+ import json
2
+ import webdataset as wds
3
+ from tqdm import tqdm
4
+ from PIL import Image
5
+ import torch
6
+ import numpy as np
7
+ import os
8
+ import time
9
+ import cv2
10
+ import random
11
+ import math
12
+ from open_flamingo.eval.task.utils import (
13
+ get_object_from_text,
14
+ is_correct,
15
+ _eval_text_image,
16
+ get_bbox,
17
+ get_iou,
18
+ )
19
+ DATASET = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/COLA/data/COLA_multiobjects_matching_benchmark.json"
20
+ VG_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/VG_100K"
21
+
22
+ def get_score(image, text, model, tokenizer, image_processor, vis_embed_size):
23
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
24
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
25
+ object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
26
+ text = text.split("#")
27
+ obj_A = text[0].strip().split(" ")
28
+ relation = text[1].strip()
29
+ obj_B = text[2].strip().split(" ")
30
+ if "computer mouse" not in text[0].strip():
31
+ attrAs = obj_A[:-1]
32
+ nounA = obj_A[-1]
33
+ else:
34
+ attrAs = obj_A[:-2]
35
+ nounA = " ".join(obj_A[-2:])
36
+ if "computer mouse" not in text[2].strip():
37
+ attrBs = obj_B[:-1]
38
+ nounB = obj_B[-1]
39
+ else:
40
+ attrBs = obj_B[:-2]
41
+ nounB = " ".join(obj_B[-2:])
42
+ # print("="*80)
43
+ # print(attrAs, nounA)
44
+ # print(attrBs, nounB)
45
+ # print(relation)
46
+ # print("="*80)
47
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
48
+
49
+
50
+ prompt1 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>the {nounA}<|#endofobject#|><|#visual#|>"]
51
+ boxes, scores = get_bbox(None, batch_images, prompt1, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
52
+
53
+
54
+ # open_cv_image = np.array(image)
55
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
56
+ # for pre_box in boxes:
57
+ # open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
58
+
59
+ box_ppl = []
60
+ box_attr_losses = []
61
+ for box in boxes:
62
+ losses = []
63
+ for attrA in attrAs:
64
+ prompt2 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {attrA} {nounA}"]
65
+ encodings = tokenizer(
66
+ prompt2,
67
+ padding="longest",
68
+ truncation=True,
69
+ return_tensors="pt",
70
+ max_length=512,
71
+ )
72
+ input_ids = encodings["input_ids"]
73
+ attention_mask = encodings["attention_mask"]
74
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
75
+ image_start_index_list = [[x] for x in image_start_index_list]
76
+ image_nums = [1] * len(input_ids)
77
+ vision_x = batch_images.cuda()
78
+ lang_x = input_ids.cuda()
79
+ attention_mask = attention_mask.cuda()
80
+ labels = lang_x.clone()
81
+ start_idx = (labels == object_token_id).nonzero()[-1, -1]
82
+ labels[0, :start_idx+1] = -100
83
+ added_bbox_list = [torch.tensor(box / 224.0).cuda().unsqueeze(0)]
84
+ with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
85
+ outputs = model(
86
+ vision_x=vision_x,
87
+ lang_x=lang_x,
88
+ attention_mask=attention_mask,
89
+ labels=labels,
90
+ image_nums=image_nums,
91
+ image_start_index_list=image_start_index_list,
92
+ added_bbox_list=added_bbox_list,
93
+ add_box=added_bbox_list is not None,
94
+ relations=None,
95
+ )
96
+ loss = outputs.loss
97
+ loss = (loss.sum() / (loss != 0).sum()).item()
98
+ losses.append(loss)
99
+ avg_ppl = np.array(losses).mean()
100
+ box_ppl.append(avg_ppl)
101
+ box_attr_losses.append(losses)
102
+ fit_idx = np.array(box_ppl).argmin()
103
+ fit_box = boxes[fit_idx]
104
+ fit_attr = attrAs[np.array(box_attr_losses[fit_idx]).argmin()]
105
+ first_ppl = min(box_ppl)
106
+
107
+ # open_cv_image = cv2.rectangle(open_cv_image, fit_box[:2].astype(int), fit_box[2:].astype(int), (255, 0, 0), 2)
108
+
109
+
110
+ prompt3 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>the {fit_attr} {nounA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> is {relation}<|#object#|><|#previsual#|>"]
111
+ boxes, scores = get_bbox([torch.tensor(fit_box / 224).cuda().unsqueeze(0)], batch_images, prompt3, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
112
+ # for i, pre_box in enumerate(boxes):
113
+ # open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 0, 255), i+1)
114
+ # cv2.imwrite(f"Atest.png", open_cv_image)
115
+
116
+ box_ppl = []
117
+ for box in boxes:
118
+ losses = []
119
+ for attrB in attrBs:
120
+ prompt4 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>the {fit_attr} {nounA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {attrB} {nounB}"]
121
+ encodings = tokenizer(
122
+ prompt4,
123
+ padding="longest",
124
+ truncation=True,
125
+ return_tensors="pt",
126
+ max_length=512,
127
+ )
128
+ input_ids = encodings["input_ids"]
129
+ attention_mask = encodings["attention_mask"]
130
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
131
+ image_start_index_list = [[x] for x in image_start_index_list]
132
+ image_nums = [1] * len(input_ids)
133
+ vision_x = batch_images.cuda()
134
+ lang_x = input_ids.cuda()
135
+ attention_mask = attention_mask.cuda()
136
+ labels = lang_x.clone()
137
+ start_idx = (labels == object_token_id).nonzero()[-1, -1]
138
+ labels[0, :start_idx+1] = -100
139
+ added_bbox_list = [torch.tensor(fit_box / 224.0).cuda().unsqueeze(0), torch.tensor(box / 224.0).cuda().unsqueeze(0)]
140
+ with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
141
+ outputs = model(
142
+ vision_x=vision_x,
143
+ lang_x=lang_x,
144
+ attention_mask=attention_mask,
145
+ labels=labels,
146
+ image_nums=image_nums,
147
+ image_start_index_list=image_start_index_list,
148
+ added_bbox_list=added_bbox_list,
149
+ add_box=added_bbox_list is not None,
150
+ relations=None,
151
+ )
152
+ loss = outputs.loss
153
+ loss = (loss.sum() / (loss != 0).sum()).item()
154
+ losses.append(loss)
155
+ avg_ppl = np.array(losses).mean()
156
+ box_ppl.append(avg_ppl)
157
+ second_ppl = (np.array(box_ppl) * np.array(scores)).sum() / sum(scores)
158
+ return (first_ppl + second_ppl) / 2
159
+
160
+
161
+ def evaluate_cola(
162
+ model,
163
+ tokenizer,
164
+ image_processor,
165
+ vis_embed_size=None,
166
+ rank=0,
167
+ world_size=1,
168
+ id=0,
169
+ debug=False,
170
+ ):
171
+ dataset_name = "cola"
172
+ dataset = json.load(open(DATASET))
173
+ model = model.cuda().eval()
174
+ correct = 0
175
+ total = 0
176
+ pbar = tqdm(dataset, disable=(rank != 0))
177
+ for ii, sample in enumerate(pbar):
178
+ if ii % world_size != rank:
179
+ continue
180
+ image1 = Image.open(os.path.join(VG_ROOT, os.path.basename(sample[0]))).convert("RGB").resize((224, 224))
181
+ text1 = sample[1]
182
+ image2 = Image.open(os.path.join(VG_ROOT, os.path.basename(sample[2]))).convert("RGB").resize((224, 224))
183
+ text2 = sample[3]
184
+ score11 = -get_score(image1, text1, model, tokenizer, image_processor, vis_embed_size)
185
+ score12 = -get_score(image1, text2, model, tokenizer, image_processor, vis_embed_size)
186
+ score21 = -get_score(image2, text1, model, tokenizer, image_processor, vis_embed_size)
187
+ score22 = -get_score(image2, text2, model, tokenizer, image_processor, vis_embed_size)
188
+ if rank == 0:
189
+ tqdm.write(f"{score11:.2f} {score12:.2f} {score21:.2f} {score22:.2f}")
190
+ if score11 > score21 and score22 > score12:
191
+ correct += 1
192
+ total += 1
193
+ pbar.set_description(f"{correct / total:.2f}")
194
+ print(rank, correct / total)
195
+
196
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
197
+ f.write(json.dumps([total, correct]))
198
+ if world_size > 1:
199
+ torch.distributed.barrier()
200
+ if rank == 0:
201
+ total = 0
202
+ correct = 0
203
+ print(f"evaluate on rank {rank}. world size is {world_size}")
204
+ for rank_i in range(world_size):
205
+ [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
206
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
207
+ total += total_part
208
+ correct += correct_part
209
+ score = correct / total
210
+ print("score:", score)
211
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}_{total}"), "w") as f:
212
+ pass
213
+ else:
214
+ score = 0.0
215
+ if world_size > 1:
216
+ torch.distributed.barrier()
217
+ return score
218
+
219
+ if __name__ == "__main__":
220
+ evaluate_cola(None, None, None)
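Two pieces of logic above are easy to miss: the per-caption score is a masked mean of token losses (positions labelled -100 contribute zero loss and are excluded from the denominator), and a COLA sample only counts as correct when each caption scores higher with its own image than with the other image. A model-free sketch with toy numbers:

import torch

# masked-mean trick from get_score(): zero entries are the ignored (-100) label positions
token_losses = torch.tensor([0.0, 0.0, 2.0, 1.0, 3.0])
per_caption_loss = (token_losses.sum() / (token_losses != 0).sum()).item()
print(per_caption_loss)  # 2.0, averaged over the three scored tokens only

def group_correct(score11, score12, score21, score22):
    # text1 must score higher with image1 than with image2, and text2 higher with image2
    return score11 > score21 and score22 > score12

print(group_correct(-1.0, -2.5, -1.8, -2.0))  # True for these toy (negated-loss) scores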
multimodal/build/lib/open_flamingo/eval/task/crepe.py ADDED
@@ -0,0 +1,93 @@
1
+ import json
2
+ import webdataset as wds
3
+ from tqdm import tqdm
4
+ from PIL import Image
5
+ import torch
6
+ import numpy as np
7
+ import os
8
+ import time
9
+ import cv2
10
+ import random
11
+ import pandas as pd
12
+ from .vl_checklist import _eval_text_image
13
+ DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/instruct_data/crepe/prod_hard_negatives"
14
+
15
+
16
+ def evaluate_crepe(
17
+ model,
18
+ tokenizer,
19
+ image_processor,
20
+ vis_embed_size=None,
21
+ rank=0,
22
+ world_size=1,
23
+ id=0,
24
+ subset=True,
25
+ debug=False,
26
+ level=4,
27
+ type="swap",
28
+ ):
29
+ if rank == 0:
30
+ tqdm.write(f"level: {level}")
31
+ tqdm.write(f"type: {type}")
32
+ dataset_name = "crepe"
33
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
34
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
35
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
36
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
37
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
38
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
39
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
40
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
41
+ model.eval().cuda()
42
+ total = 0
43
+ correct = 0
44
+ assert type in ["swap"]
45
+ assert 4 <= level <= 12
46
+ filename = os.path.join(DATASET_ROOT, type, f"prod_vg_hard_negs_{type}_complexity_{level}.csv")
47
+ df = pd.read_csv(filename)
48
+ pbar = tqdm(df.iterrows(), disable=(rank != 0))
49
+ for ii, sample in pbar:
50
+ if ii % world_size != rank:
51
+ continue
52
+ text = sample.caption
53
+ image_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/VG_100K/{}.jpg".format(sample.image_id)
54
+ x = sample.x
55
+ y = sample.y
56
+ width = sample.width
57
+ height = sample.height
58
+ image = Image.open(image_path).convert("RGB")
59
+ image = image.crop((x, y, x+width, y+height))
60
+ image = image.resize((224, 224))
61
+ final_rank, final_ranks = _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=debug)
62
+ if final_rank is None:
63
+ continue
64
+ correct += int((np.array(final_ranks) < 10).sum())
65
+ total += len(final_ranks)
66
+ if debug:
67
+ tqdm.write("="*80)
68
+ pbar.set_description(f"{text} | score: {correct / total:.4f} | {final_rank} | {final_ranks}")
69
+
70
+
71
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
72
+ f.write(json.dumps([total, correct]))
73
+ if world_size > 1:
74
+ torch.distributed.barrier()
75
+ if rank == 0:
76
+ total = 0
77
+ correct = 0
78
+ print(f"evaluate on rank {rank}. world size is {world_size}")
79
+ for rank_i in range(world_size):
80
+ [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
81
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
82
+ total += total_part
83
+ correct += correct_part
84
+ score = correct / total
85
+ print("score:", score, "total:", total)
86
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
87
+ pass
88
+ else:
89
+ score = 0.0
90
+ if world_size > 1:
91
+ torch.distributed.barrier()
92
+ return score
93
+
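The CREPE accuracy above is a rank-based count: for each probed position, _eval_text_image reports the rank of the ground-truth token, and a position counts as correct when that rank is below 10. A toy tally with made-up ranks:

import numpy as np

final_ranks = np.array([0, 3, 27, 9, 114])  # hypothetical ranks of the target tokens
correct = int((final_ranks < 10).sum())     # 3 of the 5 positions land in the top 10
total = len(final_ranks)
print(correct / total)                      # 0.6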
multimodal/build/lib/open_flamingo/eval/task/gqa.py ADDED
@@ -0,0 +1,248 @@
1
+ from torch.utils.data import Dataset
2
+ import json
3
+ from PIL import Image
4
+ import os
5
+ import torch
6
+ import more_itertools
7
+ from tqdm import tqdm
8
+ import time
9
+ from vqa_metric import compute_gqa_accuracy
10
+ import string
11
+ import uuid
12
+ import numpy as np
13
+ import cv2
14
+ from open_flamingo.eval.task.utils import get_bbox
15
+
16
+ class GQADataset(Dataset):
17
+ def __init__(
18
+ self,
19
+ image_dir_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/gqa/images",
20
+ annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/gqa/testdev_balanced_questions.json",
21
+ ):
22
+ annotations = json.load(open(annotations_path))
23
+ self.questions = []
24
+ self.answers = []
25
+ self.image_paths = []
26
+ self.question_ids = []
27
+ for anno_id in annotations:
28
+ question = annotations[anno_id]["question"]
29
+ imageId = annotations[anno_id]["imageId"]
30
+ answer = annotations[anno_id]["answer"]
31
+ self.questions.append(question)
32
+ self.answers.append(answer)
33
+ self.image_paths.append(os.path.join(image_dir_path, "{}.jpg".format(imageId)))
34
+ self.question_ids.append(anno_id)
35
+ # print(annotations[anno_id]["types"])
36
+ self.vqa_dataset = "gqa"
37
+
38
+ def __len__(self):
39
+ return len(self.questions)
40
+
41
+ def __getitem__(self, idx):
42
+ question = self.questions[idx]
43
+ question_id = self.question_ids[idx]
44
+ answer = self.answers[idx]
45
+ img_path = self.image_paths[idx]
46
+ image = Image.open(img_path)
47
+ return {
48
+ "image": image,
49
+ "question": question,
50
+ "answers": answer,
51
+ "question_id": question_id,
52
+ }
53
+
54
+
55
+ def prepare_batch_images(batch, image_processor):
56
+ batch_images = None
57
+ for b in batch:
58
+ b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
59
+ if batch_images is None:
60
+ batch_images = b_image
61
+ else:
62
+ batch_images = torch.cat([batch_images, b_image], dim=0)
63
+ return batch_images
64
+
65
+
66
+
67
+ def evaluate_gqa(
68
+ model,
69
+ tokenizer,
70
+ image_processor,
71
+ batch_size=1,
72
+ vis_embed_size=None,
73
+ rank=0,
74
+ world_size=1,
75
+ id=0,
76
+ ):
77
+ """
78
+ Evaluate a model on the GQA testdev-balanced questions split.
79
+
80
+ Args:
81
+ model (nn.Module): model to evaluate
82
+ tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
83
+ image_processor : image processor for the model
84
+ batch_size (int): batch size
85
+ image_dir_path (str): path to image directory
86
+ questions_json_path (str): path to questions json file
87
+ annotations_json_path (str): path to annotations json file
88
+ seed (int, optional): random seed. Defaults to 42.
89
+ max_generation_length (int, optional): max generation length. Defaults to 5.
90
+ num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
91
+ length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
92
+ num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
93
+ query_set_size (int, optional): size of the query set. Defaults to 2048.
94
+ num_shots (int, optional): number of shots to use. Defaults to 8.
95
+ device (int, optional): device to use. Defaults to -1 (cpu).
96
+ num_workers (int, optional): number of workers to use. Defaults to 4.
97
+ vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
98
+ Returns:
99
+ float: accuracy score
100
+ """
101
+ assert batch_size == 1
102
+ vqa_dataset = "gqa"
103
+ eval_dataset = GQADataset()
104
+ object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
105
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
106
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
107
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
108
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
109
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
110
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
111
+ def get_prompt(sample):
112
+ return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
113
+ model.eval().cuda()
114
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
115
+ predictions = []
116
+ if batch_size != 1:
117
+ tokenizer.padding_side = "left"
118
+ if world_size > 1:
119
+ torch.distributed.barrier()
120
+ this_tot = 0
121
+ for ii, batch in enumerate(more_itertools.chunked(
122
+ tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size,
123
+ )):
124
+ if ii % world_size != rank:
125
+ continue
126
+ batch[0]["image"] = batch[0]["image"].resize((224, 224))
127
+ batch_images = prepare_batch_images(
128
+ batch=batch,
129
+ image_processor=image_processor,
130
+ ).cuda()
131
+ batch_text = [get_prompt(s) for s in batch]
132
+ encodings = tokenizer(
133
+ batch_text,
134
+ return_tensors="pt",
135
+ padding="longest",
136
+ truncation=True,
137
+ max_length=2000,
138
+ )
139
+ input_ids = encodings["input_ids"].cuda()
140
+ attention_mask = encodings["attention_mask"].cuda()
141
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
142
+ image_start_index_list = [[x] for x in image_start_index_list]
143
+ image_nums = [1] * len(input_ids)
144
+ with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
145
+ outputs = model.generate(
146
+ batch_images,
147
+ input_ids,
148
+ attention_mask=attention_mask,
149
+ max_new_tokens=10,
150
+ min_length=1,
151
+ num_beams=1,
152
+ # length_penalty=0,
153
+ image_start_index_list=image_start_index_list,
154
+ image_nums=image_nums,
155
+ added_bbox_list=None,
156
+ return_dict_in_generate=True,
157
+ output_scores=True,
158
+ )
159
+ scores = outputs.scores
160
+ outputs = outputs.sequences[:, len(input_ids[0]) :]
161
+ if object_token_id in scores[0][0].sort(descending=True).indices[:5]:
162
+ sample = batch[0]
163
+ # print("="*80)
164
+ # print("sample:", batch, scores[0][0].sort(descending=True).indices[:10].tolist().index(object_token_id))
165
+ prompt1 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:<|#object#|><|#previsual#|>"]
166
+ boxes, scores = get_bbox(None, batch_images, prompt1, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
167
+ # open_cv_image = np.array(sample["image"])
168
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
169
+ # cv2.imwrite(f"Atest_ori.png", open_cv_image)
170
+ # open_cv_image = cv2.rectangle(open_cv_image, boxes[0][:2].astype(int), boxes[0][2:].astype(int), (0, 255, 0), 2)
171
+ # print(scores)
172
+ # cv2.imwrite(f"Atest.png", open_cv_image)
173
+ if boxes is not None and len(boxes) > 0:
174
+ prompt2 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer: it is<|#object#|><|#previsual#|><|#prebox#|><|#object#|> a"]
175
+ encodings = tokenizer(
176
+ prompt2,
177
+ return_tensors="pt",
178
+ padding="longest",
179
+ truncation=True,
180
+ max_length=2000,
181
+ )
182
+ input_ids = encodings["input_ids"].cuda()
183
+ attention_mask = encodings["attention_mask"].cuda()
184
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
185
+ image_start_index_list = [[x] for x in image_start_index_list]
186
+ image_nums = [1] * len(input_ids)
187
+ added_bbox_list = [torch.tensor(boxes[0]/224.0).cuda().unsqueeze(0).clamp(0, 0.99)]
188
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):  # chaining with "and" would enter only one of the two context managers
189
+ outputs = model.generate(
190
+ batch_images,
191
+ input_ids,
192
+ attention_mask=attention_mask,
193
+ max_new_tokens=10,
194
+ min_length=1,
195
+ num_beams=1,
196
+ image_start_index_list=image_start_index_list,
197
+ image_nums=image_nums,
198
+ added_bbox_list=added_bbox_list,
199
+ eos_token_id=(endofobject_token_id),
200
+ )
201
+ outputs = outputs[:, len(input_ids[0]) :]
202
+ # print("previsual===>{}".format(tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower().strip(string.punctuation+" ")))
203
+
204
+ # postprocess begin
205
+ new_predictions = [
206
+ out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
207
+ ]
208
+ this_tot += 1
209
+ predictions.extend(
210
+ [
211
+ {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
212
+ for p, sample in zip(new_predictions, batch)
213
+ ]
214
+ )
215
+ with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
216
+ f.write(json.dumps(predictions))
217
+ print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
218
+
219
+ time.sleep(10)
220
+ if world_size > 1:
221
+ torch.distributed.barrier()
222
+ if rank == 0:
223
+ print(f"evaluate on rank {rank}. world size is {world_size}")
224
+ predictions = []
225
+ for rank_i in range(world_size):
226
+ print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
227
+ predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
228
+ os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
229
+ print("num:", len(predictions))
230
+ # save the predictions to a temporary file
231
+ random_uuid = str(uuid.uuid4())
232
+ with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
233
+ f.write(json.dumps(predictions, indent=4))
234
+
235
+ acc = compute_gqa_accuracy(predictions)
236
+ print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
237
+ os.makedirs("eval_results", exist_ok=True)
238
+ with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
239
+ f.write(json.dumps(predictions, indent=2))
240
+
241
+ # delete the temporary file
242
+ os.remove(f"{vqa_dataset}results_{random_uuid}.json")
243
+ else:
244
+ time.sleep(5)
245
+ acc = 0.0
246
+ if world_size > 1:
247
+ torch.distributed.barrier()
248
+ return acc
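Note: the generation loop above relies on locating the position right after the `<|#image#|>` token so the model knows where the visual embeddings sit. Below is a minimal, self-contained sketch of that index bookkeeping on toy token ids; the ids are invented for illustration, the real ones come from the tokenizer.

```python
import torch

media_token_id = 7  # hypothetical id for <|#image#|>
input_ids = torch.tensor([
    [1, 7, 0, 0, 0, 9, 5],  # <bos> <|#image#|> <pad>x3 <|#endofimage#|> ...
    [1, 7, 0, 0, 0, 9, 6],
])

# Same expression as in the evaluation code: column index of <|#image#|> plus one.
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
image_start_index_list = [[x] for x in image_start_index_list]
image_nums = [1] * len(input_ids)
print(image_start_index_list, image_nums)  # [[2], [2]] [1, 1]
```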
multimodal/build/lib/open_flamingo/eval/task/mmbench.py ADDED
@@ -0,0 +1,84 @@
1
+ import base64
2
+ import io
3
+ import random
4
+
5
+ import pandas as pd
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+ from open_flamingo.eval.task.utils import get_object_from_text
9
+
10
+ def decode_base64_to_image(base64_string):
11
+ image_data = base64.b64decode(base64_string)
12
+ image = Image.open(io.BytesIO(image_data))
13
+ return image
14
+
15
+ class MMBenchDataset(Dataset):
16
+ def __init__(self,
17
+ data_file,
18
+ sys_prompt='There are several options:'):
19
+ self.df = pd.read_csv(data_file, sep='\t')
20
+ self.sys_prompt = sys_prompt
21
+
22
+ def __len__(self):
23
+ return len(self.df)
24
+
25
+ def __getitem__(self, idx):
26
+ index = self.df.iloc[idx]['index']
27
+ image = self.df.iloc[idx]['image']
28
+ image = decode_base64_to_image(image)
29
+ question = self.df.iloc[idx]['question']
30
+ answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[0].keys() else None
31
+ category = self.df.iloc[idx]['category']
32
+ l2_category = self.df.iloc[idx]['l2-category']
33
+
34
+ option_candidate = ['A', 'B', 'C', 'D', 'E']
35
+ options = {
36
+ cand: self.load_from_df(idx, cand)
37
+ for cand in option_candidate
38
+ if self.load_from_df(idx, cand) is not None
39
+ }
40
+ options_prompt = f'{self.sys_prompt}\n'
41
+ for key, item in options.items():
42
+ options_prompt += f'{key}. {item}\n'
43
+
44
+ hint = self.load_from_df(idx, 'hint')
45
+ data = {
46
+ 'img': image,
47
+ 'question': question,
48
+ 'answer': answer,
49
+ 'options': options_prompt,
50
+ 'category': category,
51
+ 'l2-category': l2_category,
52
+ 'options_dict': options,
53
+ 'index': index,
54
+ 'context': hint,
55
+ }
56
+ return data
57
+ def load_from_df(self, idx, key):
58
+ if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
59
+ return self.df.iloc[idx][key]
60
+ else:
61
+ return None
62
+
63
+
64
+ def evaluate_mmbench(
65
+ model,
66
+ tokenizer,
67
+ image_processor,
68
+ batch_size=1,
69
+ image_dir_path=None,
70
+ questions_json_path=None,
71
+ annotations_json_path=None,
72
+ vis_embed_size=None,
73
+ rank=0,
74
+ world_size=1,
75
+ id=0,
76
+ ):
77
+ dataset_name = "mmbench"
78
+ dataset = MMBenchDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/mmbench/mmbench_dev_20230712.tsv")
79
+ for sample in dataset:
80
+ print(sample)
81
+
82
+
83
+ if __name__ == '__main__':
84
+ evaluate_mmbench(None, None, None)
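Note: MMBenchDataset reads each image as a base64 string out of the TSV and assembles the multiple-choice prompt from whichever of the A–E columns are present. A self-contained sketch of both steps with made-up data (the real TSV schema is only assumed here):

```python
import base64
import io

from PIL import Image

# Round-trip a tiny in-memory image through base64, as decode_base64_to_image does.
buf = io.BytesIO()
Image.new("RGB", (8, 8), color=(255, 0, 0)).save(buf, format="PNG")
encoded = base64.b64encode(buf.getvalue()).decode("ascii")
decoded = Image.open(io.BytesIO(base64.b64decode(encoded)))
print(decoded.size)  # (8, 8)

# Options prompt assembly, mirroring __getitem__ for a toy row.
sys_prompt = "There are several options:"
options = {"A": "a cat", "B": "a dog"}
options_prompt = f"{sys_prompt}\n"
for key, item in options.items():
    options_prompt += f"{key}. {item}\n"
print(options_prompt)
```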
multimodal/build/lib/open_flamingo/eval/task/reg.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ from tqdm import tqdm
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ import base64
6
+ import numpy as np
7
+ import time
8
+ import json
9
+ import os
10
+ import cv2
11
+ from coco_metric import compute_cider
12
+ import random
13
+ import pickle
14
+
15
+ def evaluate_reg(
16
+ model,
17
+ tokenizer,
18
+ image_processor,
19
+ vis_embed_size=None,
20
+ rank=0,
21
+ world_size=1,
22
+ id=0,
23
+ ):
24
+ lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
25
+ dataset_name = "refcocog"
26
+ pkl_file = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_data.pkl"
27
+ try:
28
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
29
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
30
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
31
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
32
+ except:
33
+ pass
34
+
35
+ model.eval().cuda()
36
+ if world_size > 1:
37
+ torch.distributed.barrier()
38
+ this_tot = 0
39
+ predictions = []
40
+ D = pickle.load(open(pkl_file, "rb"))
41
+ lines = []
42
+ data = D["data"]
43
+ uniq_id_to_text = D["uniq_id_to_text"]
44
+ uniq_id_to_image = D["uniq_id_to_image"]
45
+ uniq_id_to_image_id = D["uniq_id_to_image_id"]
46
+ for image_id in data:
47
+ for region in data[image_id]:
48
+ uniq_id = data[image_id][region][0]
49
+ lines.append([uniq_id, uniq_id_to_image_id[uniq_id], [uniq_id_to_text[r] for r in data[image_id][region]], region, uniq_id_to_image[uniq_id]])
50
+ print("total data:", len(lines))
51
+ # lines = lines[:20]
52
+ pbar = tqdm(lines, disable=(rank != 0))
53
+ for ii, line in enumerate(pbar):
54
+ if ii % world_size != rank:
55
+ continue
56
+ uniq_id, image_id, text, region_coord, image = line
57
+ gt_box = np.array(region_coord)
58
+ width = image.width
59
+ height = image.height
60
+ image = image.resize((224, 224))
61
+ gt_box = gt_box / np.array([width, height, width, height]) * 224
62
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
63
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#object#|>"]
64
+
65
+ encodings = tokenizer(
66
+ prompt,
67
+ padding="longest",
68
+ truncation=True,
69
+ return_tensors="pt",
70
+ max_length=2000,
71
+ )
72
+ input_ids = encodings["input_ids"]
73
+ attention_mask = encodings["attention_mask"]
74
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
75
+ image_start_index_list = [[x] for x in image_start_index_list]
76
+ image_nums = [1] * len(input_ids)
77
+ batch_images = batch_images.cuda()
78
+ input_ids = input_ids.cuda()
79
+ attention_mask = attention_mask.cuda()
80
+ added_bbox_list = [(torch.tensor(gt_box).cuda() / 224).clamp(0, 0.99).unsqueeze(0)]
81
+
82
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):  # chaining with "and" would enter only one of the two context managers
83
+ outputs = model.generate(
84
+ batch_images,
85
+ input_ids,
86
+ attention_mask=attention_mask,
87
+ max_new_tokens=25,
88
+ min_length=5,
89
+ num_beams=8,
90
+ length_penalty=0,
91
+ image_start_index_list=image_start_index_list,
92
+ image_nums=image_nums,
93
+ added_bbox_list=added_bbox_list,
94
+ )
95
+ outputs = outputs[:, len(input_ids[0]) :]
96
+ new_prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip().lower()
97
+ this_tot += 1
98
+ if rank == 0 and this_tot % 10 == 0:
99
+ for i in range(1):
100
+ tqdm.write(f"answer: {text}\nmodel output: {new_prediction}")
101
+ predictions.append(
102
+ {"image_id": image_id, "caption": new_prediction}
103
+ )
104
+ results_path = f"reg_{lang_encoder_name}_{rank}_{id}.json"
105
+ json.dump(predictions, open(results_path, "w"))
106
+ print("save to", results_path)
107
+ del predictions
108
+ time.sleep(5)
109
+ if world_size > 1:
110
+ torch.distributed.barrier()
111
+ if rank == 0:
112
+ print(f"evaluate on rank {rank}. world size is {world_size}")
113
+ predictions = []
114
+ for rank_i in range(world_size):
115
+ part_results_path = f"reg_{lang_encoder_name}_{rank_i}_{id}.json"
116
+ print("load", part_results_path)
117
+ part_data = json.load(open(part_results_path))
118
+ predictions.extend(part_data)
119
+ os.remove(part_results_path)
120
+ print("num:", len(predictions))
121
+ results_path = f"reg_{lang_encoder_name}_{id}_result.json"
122
+ json.dump(predictions, open(results_path, "w"), indent=2)
123
+
124
+ metrics = compute_cider(
125
+ result_path=results_path,
126
+ annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_label.json",
127
+ )
128
+ os.makedirs("eval_results", exist_ok=True)
129
+ cider = metrics["CIDEr"]
130
+ print("cider", cider)
131
+ with open(os.path.join("eval_results", f"reg_{model.expr_name}_{model.step_num}_{int(time.time())}_{cider}"), "w") as f:
132
+ f.write(json.dumps(predictions, indent=2))
133
+ # delete the temporary file
134
+ os.remove(results_path)
135
+ return cider
136
+
137
+
138
+ if __name__ == "__main__":
139
+ anno = json.load(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json"))
140
+ import pdb; pdb.set_trace()
141
+ print(anno.keys())
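Note: evaluate_reg resizes every image to 224x224 and rescales the ground-truth region into the unit square (clamped just below 1) before passing it as `added_bbox_list`. A small sketch of that normalization with an invented box:

```python
import numpy as np
import torch

width, height = 640, 480
gt_box = np.array([100.0, 50.0, 400.0, 300.0])  # x1, y1, x2, y2 in original pixels

# Same two steps as in evaluate_reg: map to the 224 grid, then to [0, 1).
gt_box = gt_box / np.array([width, height, width, height]) * 224
added_bbox = (torch.tensor(gt_box) / 224).clamp(0, 0.99).unsqueeze(0)
print(added_bbox)  # roughly [[0.1562, 0.1042, 0.6250, 0.6250]]
```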
multimodal/build/lib/open_flamingo/eval/task/utils.py ADDED
@@ -0,0 +1,287 @@
1
+ import spacy
2
+ import torch
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import itertools
6
+ nlp = spacy.load('en_core_web_md')
7
+
8
+
9
+ def get_iou(box1, box2):
10
+ # box1 and box2 should be in the format [x1, y1, x2, y2]
11
+ intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
12
+ max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
13
+ area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
14
+ area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
15
+ union = area_box1 + area_box2 - intersection
16
+ iou = intersection / union if union > 0 else 0
17
+ return iou
18
+
19
+
20
+ # def find_root(token):
21
+ # if token.pos_ == "VERB":
22
+ # return token
23
+ # while token.dep_ not in ["pobj", "nsubj", "ROOT", "npadvmod", "dobj", "det", "prep", "punct", "cc", "conj", "acl", "dep", "appos", "relcl", "advmod", "nmod", "attr"]:
24
+ # token = token.head
25
+ # return token
26
+
27
+
28
+ def find_root(token):
29
+ if token.pos_ == "VERB":
30
+ return token
31
+ while token.dep_ in ["compound", "amod"]:
32
+ token = token.head
33
+ return token
34
+
35
+ def get_object_from_text(text, verbose=False):
36
+ if len(text.split(" ")) == 3:
37
+ text = text.split(" ")
38
+ return [text[0], text[-1]]
39
+ doc = nlp(text)
40
+ if verbose:
41
+ for TT in doc:
42
+ print(TT.text, TT.pos_, TT.dep_, TT.head)
43
+ roots = set()
44
+ for i, token in enumerate(doc):
45
+ roots.add(find_root(token))
46
+ exprs = []
47
+ roots = sorted(list(roots), key=lambda token: token.idx)
48
+ first_nsubj = True
49
+ if verbose:
50
+ print(roots)
51
+ for root in roots:
52
+ if root.pos_ not in ["NOUN", "PROPN"]:
53
+ continue
54
+ if root.dep_ not in ["pobj", "nsubj"]:
55
+ continue
56
+ if not first_nsubj and root.dep_ in ["nsubj"]:
57
+ continue
58
+ exprs.append([])
59
+ for token in doc:
60
+ if find_root(token) == root:
61
+ exprs[-1].append(token.text)
62
+ exprs[-1] = " ".join(exprs[-1]).replace(" '", "'")
63
+ if exprs[-1] not in text:
64
+ if verbose:
65
+ print("not in text error:", exprs[-1], "#",text)
66
+ # for TT in doc:
67
+ # print(TT.text, TT.pos_, TT.dep_, TT.head)
68
+ # import pdb; pdb.set_trace()
69
+ exprs.pop()
70
+ if first_nsubj and root.dep_ in ["nsubj"]:
71
+ first_nsubj = False
72
+ if len(exprs) <= 1:
73
+ if verbose:
74
+ print("not enough exprs error:", exprs, "#",text)
75
+ return []
76
+ return exprs
77
+
78
+ def is_correct(input_ids, logits, tokenizer, object: str, topk=5, N=10):
79
+ answer_id = torch.tensor(tokenizer(f" {object}", add_special_tokens=False)["input_ids"]).to(input_ids.device)
80
+ answer_begin_idx = (input_ids == answer_id[0]).nonzero()
81
+ answer_idx = None
82
+ for (batch_idx, IDX) in answer_begin_idx:
83
+ try:
84
+ if (input_ids[batch_idx, IDX:IDX+len(answer_id)] == answer_id).all():
85
+ answer_idx = list(range(IDX-1, IDX+len(answer_id)-1))
86
+ except:
87
+ pass
88
+ if answer_idx is None:
89
+ return np.inf, False, False
90
+ res = logits[0, answer_idx].softmax(-1).sort(descending=True)
91
+ values = res.values
92
+ indices = res.indices
93
+ chosen_ids = list(itertools.product(*([list(range(N))]*len(answer_idx))))
94
+ probs = []
95
+ for ids in chosen_ids:
96
+ prob = 1.0
97
+ for i, id in enumerate(ids):
98
+ prob *= values[i, id]
99
+ probs.append((prob.item(), ids))
100
+ probs.sort(reverse=True)
101
+ answer_pos = tuple([id_array.tolist().index(idx) for id_array, idx in zip(indices, answer_id)])
102
+ ranking = [p[1] for p in probs]
103
+ # if len(answer_idx) > 1:
104
+ # import pdb; pdb.set_trace()
105
+ try:
106
+ r = ranking.index(answer_pos)
107
+ return r, r < 1, r < 5
108
+ except:
109
+ return np.inf, False, False
110
+
111
+ def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
112
+ assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
113
+ encodings = tokenizer(
114
+ prompt,
115
+ padding="longest",
116
+ truncation=True,
117
+ return_tensors="pt",
118
+ max_length=2000,
119
+ )
120
+ input_ids = encodings["input_ids"]
121
+ attention_mask = encodings["attention_mask"]
122
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
123
+ image_start_index_list = [[x] for x in image_start_index_list]
124
+ image_nums = [1] * len(input_ids)
125
+ vision_x = batch_images.cuda()
126
+ lang_x = input_ids.cuda()
127
+ attention_mask = attention_mask.cuda()
128
+
129
+ model.debug_id = 0
130
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):  # chaining with "and" would enter only one of the two context managers
131
+ outputs = model(
132
+ vision_x=vision_x,
133
+ lang_x=lang_x,
134
+ attention_mask=attention_mask,
135
+ labels=None,
136
+ image_nums=image_nums,
137
+ image_start_index_list=image_start_index_list,
138
+ added_bbox_list=visual_box_list,
139
+ add_box=visual_box_list is not None,
140
+ relations=None,
141
+ debug_mode=False,
142
+ )
143
+ boxes = outputs["boxes"]
144
+ scores = outputs["scores"]
145
+ if debug:
146
+ import pdb; pdb.set_trace()
147
+ if return_all:
148
+ return boxes, scores
149
+ if len(scores) == 0:
150
+ return None, None
151
+ else:
152
+ return boxes[scores.argmax()], scores.max()
153
+
154
+
155
+ def _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=False, objects=None):
156
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
157
+ if objects is None:
158
+ objects = get_object_from_text(text)
159
+ if len(objects) == 0:
160
+ return None, None, None
161
+ if debug:
162
+ tqdm.write(text)
163
+ tqdm.write(f"{objects}")
164
+ first_idx = text.find(objects[0])
165
+ if first_idx == 0:
166
+ first_text = f"<|#object#|>{objects[0]}<|#endofobject#|><|#visual#|>"
167
+ else:
168
+ first_text = text[:first_idx-1] + f"<|#object#|> {objects[0]}<|#endofobject#|><|#visual#|>"
169
+
170
+ if debug:
171
+ tqdm.write(first_text)
172
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
173
+ # import pdb; pdb.set_trace()
174
+ # print("do first get_bbox |", first_text)
175
+ first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
176
+ if not model.valid and debug:
177
+ import pdb; pdb.set_trace()
178
+ if first_box is not None:
179
+ added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
180
+ text = first_text + "<|#box#|><|#endofobject#|>" + text[first_idx+len(objects[0]):]
181
+ else:
182
+ added_bbox_list = []
183
+
184
+ final_ranks = []
185
+ is_top1_list = []
186
+ is_top5_list = []
187
+ for kk, object in enumerate(objects):
188
+ if kk == 0:
189
+ continue
190
+ idx = text.find(objects[0])
191
+ for t_i, temp in enumerate(objects[1:kk+1]):
192
+ # t_i is actually the previous one. This is not a bug
193
+ idx = text.find(temp, idx + len(objects[t_i]))
194
+ while idx+len(temp) != len(text) and (text[idx-1] == "#" or text[idx+len(temp)] == "#"):
195
+ # in case temp is box or object or visual or something like that
196
+ idx = text.find(temp, idx + len(temp))
197
+ this_text = text[:idx-1] + "<|#object#|><|#previsual#|>"
198
+ # if this_text == "<|#object#|><|#previsual#|>":
199
+ # import pdb; pdb.set_trace()
200
+ if debug:
201
+ tqdm.write(this_text)
202
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{this_text}"]
203
+ # import pdb; pdb.set_trace()
204
+ # print("do pre get_bbox |", this_text)
205
+ pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
206
+ prebox_token_id, return_all=True)
207
+ if not model.valid and debug:
208
+ import pdb; pdb.set_trace()
209
+ logits_list = []
210
+ # pre_boxes = [pre_boxes[0]]
211
+ # pre_scores = [pre_scores[0]]
212
+ this_text = this_text + f"<|#prebox#|><|#object#|> {object}<|#endofobject#|>"
213
+ for pre_box, pre_score in zip(pre_boxes, pre_scores):
214
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{this_text}"]
215
+ encodings = tokenizer(
216
+ prompt,
217
+ padding="longest",
218
+ truncation=True,
219
+ return_tensors="pt",
220
+ max_length=512,
221
+ )
222
+ input_ids = encodings["input_ids"]
223
+ attention_mask = encodings["attention_mask"]
224
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
225
+ image_start_index_list = [[x] for x in image_start_index_list]
226
+ image_nums = [1] * len(input_ids)
227
+ vision_x = batch_images.cuda()
228
+ lang_x = input_ids.cuda()
229
+ attention_mask = attention_mask.cuda()
230
+ this_added_bbox_list = added_bbox_list + [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
231
+
232
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():  # chaining with "and" would enter only one of the two context managers
233
+ outputs = model(
234
+ vision_x=vision_x,
235
+ lang_x=lang_x,
236
+ attention_mask=attention_mask,
237
+ image_nums=image_nums,
238
+ image_start_index_list=image_start_index_list,
239
+ added_bbox_list=this_added_bbox_list,
240
+ add_box=this_added_bbox_list is not None and len(this_added_bbox_list) != 0,
241
+ relations=None,
242
+ )
243
+ if not model.valid and debug:
244
+ import pdb; pdb.set_trace()
245
+ logits_list.append([pre_score, outputs.logits])
246
+ if debug:
247
+ answer_start_idx = (lang_x == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1]
248
+ logits = outputs["logits"][0, answer_start_idx:]
249
+ tqdm.write(tokenizer.decode(logits[0].sort(descending=True).indices.tolist()[:10]))
250
+ # if debug:
251
+ # image.save("Atest.png")
252
+ # open_cv_image = np.array(image)
253
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
254
+ # if first_box is not None:
255
+ # open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
256
+ # if pre_box is not None:
257
+ # open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
258
+ # cv2.imwrite(f"Atest.png", open_cv_image)
259
+ # import pdb; pdb.set_trace()
260
+ pre_scores = np.array([x[0] for x in logits_list])
261
+ final_probs = 0.0
262
+ for score, (_, logits) in zip(pre_scores, logits_list):
263
+ final_probs += score * logits.softmax(-1)
264
+ assert input_ids.shape[:2] == final_probs.shape[:2]
265
+ _rank, is_top1, is_top5 = is_correct(input_ids, final_probs, tokenizer, object, topk=5)
266
+ final_ranks.append(_rank)
267
+ is_top1_list.append(is_top1)
268
+ is_top5_list.append(is_top5)
269
+ this_text = text[:idx-1] + f"<|#object#|> {object}<|#endofobject#|><|#visual#|>"
270
+ if debug:
271
+ tqdm.write(this_text)
272
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{this_text}"]
273
+ # print("do this get_bbox |", this_text)
274
+ this_box, this_score = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
275
+ if not model.valid and debug:
276
+ import pdb; pdb.set_trace()
277
+ if this_box is not None:
278
+ added_bbox_list += [torch.tensor(this_box).unsqueeze(0).cuda() / 224]
279
+ text = this_text + "<|#box#|><|#endofobject#|>" + text[idx+len(object):]
280
+ return final_ranks, is_top1_list, is_top5_list
281
+
282
+
283
+
284
+
285
+ if __name__ == "__main__":
286
+ # print(get_object_from_text("there is a cookie. there is a bear. white orio cookie is next to the teddy bear. car runs on the traffic road. there is a tree.", verbose=False))
287
+ print(get_object_from_text("President speaks to an American at a business office",verbose=True))
multimodal/build/lib/open_flamingo/eval/task/vl_checklist.py ADDED
@@ -0,0 +1,113 @@
1
+ import json
2
+ import webdataset as wds
3
+ from tqdm import tqdm
4
+ from PIL import Image
5
+ import torch
6
+ import numpy as np
7
+ import os
8
+ import time
9
+ import cv2
10
+ import random
11
+ from open_flamingo.eval.task.utils import (
12
+ get_object_from_text,
13
+ is_correct,
14
+ _eval_text_image,
15
+ )
16
+ DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/instruct_data/instruct/vl_checklist/Relation/000000.tar"
17
+
18
+ def evaluate_vlc(
19
+ model,
20
+ tokenizer,
21
+ image_processor,
22
+ vis_embed_size=None,
23
+ rank=0,
24
+ world_size=1,
25
+ id=0,
26
+ subset=True,
27
+ subset_size="5k",
28
+ debug=False,
29
+ ):
30
+ dataset_name = "vlc"
31
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
32
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
33
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
34
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
35
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
36
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
37
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
38
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
39
+ model.eval().cuda()
40
+ total = 0
41
+ n_top1 = 0
42
+ n_top5 = 0
43
+ n_top10 = 0
44
+ filename = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/vlc_data.json" if not subset else f"/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/vlc_data_subset_{subset_size}.json"
45
+ dataset = json.load(open(filename))
46
+
47
+ pbar = tqdm(dataset, disable=(rank != 0))
48
+ for ii, sample in enumerate(pbar):
49
+ if ii % world_size != rank:
50
+ continue
51
+ text, image_path = sample
52
+ image = Image.open(image_path).convert("RGB")
53
+ image = image.resize((224, 224))
54
+ final_ranks, is_top1_list, is_top5_list = _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=debug)
55
+ if final_ranks is None:
56
+ continue
57
+ n_top1 += int(sum(is_top1_list))
58
+ n_top5 += int(sum(is_top5_list))
59
+ n_top10 += int((np.array(final_ranks) < 10).sum())
60
+ total += len(final_ranks)
61
+ if debug:
62
+ tqdm.write("="*80)
63
+ pbar.set_description(f"acc@top1: {n_top1 / total:.4f} | acc@top5: {n_top5 / total:.4f} | acc@top10: {n_top10 / total:.4f} | {final_ranks} |{text}")
64
+
65
+
66
+ with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
67
+ f.write(json.dumps([total, n_top1, n_top5, n_top10]))
68
+ if world_size > 1:
69
+ torch.distributed.barrier()
70
+ if rank == 0:
71
+ total = 0
72
+ n_top1 = 0
73
+ n_top5 = 0
74
+ n_top10 = 0
75
+ print(f"evaluate on rank {rank}. world size is {world_size}")
76
+ for rank_i in range(world_size):
77
+ [total_part, n_top1_part, n_top5_part, n_top10_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
78
+ os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
79
+ total += total_part
80
+ n_top1 += n_top1_part
81
+ n_top5 += n_top5_part
82
+ n_top10 += n_top10_part
83
+ print("acc@top1:", n_top1 / total, "acc@top5:", n_top5 / total, "acc@top10:", n_top10 / total, "total:", total)
84
+ with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{n_top1 / total}_{n_top5 / total}_{n_top10 / total}_{total}"), "w") as f:
85
+ pass
+ score = n_top1 / total  # report top-1 accuracy; without this, `score` is undefined on rank 0
86
+ else:
87
+ score = 0.0
88
+ if world_size > 1:
89
+ torch.distributed.barrier()
90
+ return score
91
+
92
+
93
+ if __name__ == "__main__":
94
+ dataset = wds.WebDataset(DATASET_ROOT).decode().shuffle(100000).to_tuple("data.pyd", "dataset.txt", "image_path.txt")
95
+ labels = set()
96
+ texts = []
97
+ data_pair = []
98
+ if not os.path.exists("vlc_data.json"):
99
+ for sample in tqdm(dataset):
100
+ data, dataset_name, image_path = sample
101
+ text = data[-1]["POS"][0]
102
+ texts.append(text)
103
+ data_pair.append([text, image_path])
104
+ json.dump(data_pair, open("vlc_data.json", "w"), indent=1)
105
+ else:
106
+ print("data exists")
107
+ data_pair = json.load(open("vlc_data.json"))
108
+ for text, image_path in data_pair:
109
+ texts.append(text)
110
+
111
+
112
+
113
+ print(get_object_from_text("crow attacks the dove"))
multimodal/build/lib/open_flamingo/eval/vqa_metric.py ADDED
@@ -0,0 +1,594 @@
1
+ import copy
2
+ import datetime
3
+ import json
4
+ import os
5
+ import random
6
+ import re
7
+ import sys
8
+
9
+ # Interface for accessing the VQA dataset.
10
+
11
+ # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
12
+ # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
13
+
14
+ # The following functions are defined:
15
+ # VQA - VQA class that loads VQA annotation file and prepares data structures.
16
+ # getQuesIds - Get question ids that satisfy given filter conditions.
17
+ # getImgIds - Get image ids that satisfy given filter conditions.
18
+ # loadQA - Load questions and answers with the specified question ids.
19
+ # showQA - Display the specified questions and answers.
20
+ # loadRes - Load result file and create result object.
21
+
22
+ # Help on each function can be accessed by: "help(COCO.function)"
23
+
24
+
25
+ class VQA:
26
+ def __init__(self, annotation_file=None, question_file=None):
27
+ """
28
+ Constructor of VQA helper class for reading and visualizing questions and answers.
29
+ :param annotation_file (str): location of VQA annotation file
30
+ :return:
31
+ """
32
+ # load dataset
33
+ self.dataset = {}
34
+ self.questions = {}
35
+ self.qa = {}
36
+ self.qqa = {}
37
+ self.imgToQA = {}
38
+ if not annotation_file == None and not question_file == None:
39
+ print("loading VQA annotations and questions into memory...")
40
+ time_t = datetime.datetime.utcnow()
41
+ dataset = json.load(open(annotation_file, "r"))
42
+ questions = json.load(open(question_file, "r"))
43
+ print(datetime.datetime.utcnow() - time_t)
44
+ self.dataset = dataset
45
+ self.questions = questions
46
+ self.createIndex()
47
+
48
+ def createIndex(self):
49
+ # create index
50
+ print("creating index...")
51
+ imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
52
+ qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
53
+ qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
54
+ for ann in self.dataset["annotations"]:
55
+ imgToQA[ann["image_id"]] += [ann]
56
+ qa[ann["question_id"]] = ann
57
+ for ques in self.questions["questions"]:
58
+ qqa[ques["question_id"]] = ques
59
+ print("index created!")
60
+
61
+ # create class members
62
+ self.qa = qa
63
+ self.qqa = qqa
64
+ self.imgToQA = imgToQA
65
+
66
+ def info(self):
67
+ """
68
+ Print information about the VQA annotation file.
69
+ :return:
70
+ """
71
+ for key, value in self.dataset["info"].items():
72
+ print("%s: %s" % (key, value))
73
+
74
+ def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
75
+ """
76
+ Get question ids that satisfy given filter conditions. default skips that filter
77
+ :param imgIds (int array) : get question ids for given imgs
78
+ quesTypes (str array) : get question ids for given question types
79
+ ansTypes (str array) : get question ids for given answer types
80
+ :return: ids (int array) : integer array of question ids
81
+ """
82
+ imgIds = imgIds if type(imgIds) == list else [imgIds]
83
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
84
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
85
+
86
+ if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
87
+ anns = self.dataset["annotations"]
88
+ else:
89
+ if not len(imgIds) == 0:
90
+ anns = sum(
91
+ [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
92
+ [],
93
+ )
94
+ else:
95
+ anns = self.dataset["annotations"]
96
+ anns = (
97
+ anns
98
+ if len(quesTypes) == 0
99
+ else [ann for ann in anns if ann["question_type"] in quesTypes]
100
+ )
101
+ anns = (
102
+ anns
103
+ if len(ansTypes) == 0
104
+ else [ann for ann in anns if ann["answer_type"] in ansTypes]
105
+ )
106
+ ids = [ann["question_id"] for ann in anns]
107
+ return ids
108
+
109
+ def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
110
+ """
111
+ Get image ids that satisfy given filter conditions. default skips that filter
112
+ :param quesIds (int array) : get image ids for given question ids
113
+ quesTypes (str array) : get image ids for given question types
114
+ ansTypes (str array) : get image ids for given answer types
115
+ :return: ids (int array) : integer array of image ids
116
+ """
117
+ quesIds = quesIds if type(quesIds) == list else [quesIds]
118
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
119
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
120
+
121
+ if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
122
+ anns = self.dataset["annotations"]
123
+ else:
124
+ if not len(quesIds) == 0:
125
+ anns = sum(
126
+ [self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
127
+ )
128
+ else:
129
+ anns = self.dataset["annotations"]
130
+ anns = (
131
+ anns
132
+ if len(quesTypes) == 0
133
+ else [ann for ann in anns if ann["question_type"] in quesTypes]
134
+ )
135
+ anns = (
136
+ anns
137
+ if len(ansTypes) == 0
138
+ else [ann for ann in anns if ann["answer_type"] in ansTypes]
139
+ )
140
+ ids = [ann["image_id"] for ann in anns]
141
+ return ids
142
+
143
+ def loadQA(self, ids=[]):
144
+ """
145
+ Load questions and answers with the specified question ids.
146
+ :param ids (int array) : integer ids specifying question ids
147
+ :return: qa (object array) : loaded qa objects
148
+ """
149
+ if type(ids) == list:
150
+ return [self.qa[id] for id in ids]
151
+ elif type(ids) == int:
152
+ return [self.qa[ids]]
153
+
154
+ def showQA(self, anns):
155
+ """
156
+ Display the specified annotations.
157
+ :param anns (array of object): annotations to display
158
+ :return: None
159
+ """
160
+ if len(anns) == 0:
161
+ return 0
162
+ for ann in anns:
163
+ quesId = ann["question_id"]
164
+ print("Question: %s" % (self.qqa[quesId]["question"]))
165
+ for ans in ann["answers"]:
166
+ print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
167
+
168
+ def loadRes(self, resFile, quesFile):
169
+ """
170
+ Load result file and return a result object.
171
+ :param resFile (str) : file name of result file
172
+ :return: res (obj) : result api object
173
+ """
174
+ res = VQA()
175
+ res.questions = json.load(open(quesFile))
176
+ res.dataset["info"] = copy.deepcopy(self.questions["info"])
177
+ res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
178
+ res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
179
+ res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
180
+ res.dataset["license"] = copy.deepcopy(self.questions["license"])
181
+
182
+ print("Loading and preparing results... ")
183
+ time_t = datetime.datetime.utcnow()
184
+ anns = json.load(open(resFile))
185
+ assert type(anns) == list, "results is not an array of objects"
186
+ annsQuesIds = [ann["question_id"] for ann in anns]
187
+ # print set of question ids that do not have corresponding annotations
188
+
189
+ # assert set(annsQuesIds) == set(self.getQuesIds()), \
190
+ # 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
191
+ for ann in anns:
192
+ quesId = ann["question_id"]
193
+ if res.dataset["task_type"] == "Multiple Choice":
194
+ assert (
195
+ ann["answer"] in self.qqa[quesId]["multiple_choices"]
196
+ ), "predicted answer is not one of the multiple choices"
197
+ qaAnn = self.qa[quesId]
198
+ ann["image_id"] = qaAnn["image_id"]
199
+ ann["question_type"] = qaAnn["question_type"]
200
+ ann["answer_type"] = qaAnn["answer_type"]
201
+ print(
202
+ "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
203
+ )
204
+
205
+ res.dataset["annotations"] = anns
206
+ res.createIndex()
207
+ return res
208
+
209
+
210
+ class VQAEval:
211
+ def __init__(self, vqa=None, vqaRes=None, n=2):
212
+ self.n = n
213
+ self.accuracy = {}
214
+ self.evalQA = {}
215
+ self.evalQuesType = {}
216
+ self.evalAnsType = {}
217
+ self.vqa = vqa
218
+ self.vqaRes = vqaRes
219
+ if vqaRes is not None:
220
+ self.params = {"question_id": vqaRes.getQuesIds()}
221
+ self.contractions = {
222
+ "aint": "ain't",
223
+ "arent": "aren't",
224
+ "cant": "can't",
225
+ "couldve": "could've",
226
+ "couldnt": "couldn't",
227
+ "couldn'tve": "couldn't've",
228
+ "couldnt've": "couldn't've",
229
+ "didnt": "didn't",
230
+ "doesnt": "doesn't",
231
+ "dont": "don't",
232
+ "hadnt": "hadn't",
233
+ "hadnt've": "hadn't've",
234
+ "hadn'tve": "hadn't've",
235
+ "hasnt": "hasn't",
236
+ "havent": "haven't",
237
+ "hed": "he'd",
238
+ "hed've": "he'd've",
239
+ "he'dve": "he'd've",
240
+ "hes": "he's",
241
+ "howd": "how'd",
242
+ "howll": "how'll",
243
+ "hows": "how's",
244
+ "Id've": "I'd've",
245
+ "I'dve": "I'd've",
246
+ "Im": "I'm",
247
+ "Ive": "I've",
248
+ "isnt": "isn't",
249
+ "itd": "it'd",
250
+ "itd've": "it'd've",
251
+ "it'dve": "it'd've",
252
+ "itll": "it'll",
253
+ "let's": "let's",
254
+ "maam": "ma'am",
255
+ "mightnt": "mightn't",
256
+ "mightnt've": "mightn't've",
257
+ "mightn'tve": "mightn't've",
258
+ "mightve": "might've",
259
+ "mustnt": "mustn't",
260
+ "mustve": "must've",
261
+ "neednt": "needn't",
262
+ "notve": "not've",
263
+ "oclock": "o'clock",
264
+ "oughtnt": "oughtn't",
265
+ "ow's'at": "'ow's'at",
266
+ "'ows'at": "'ow's'at",
267
+ "'ow'sat": "'ow's'at",
268
+ "shant": "shan't",
269
+ "shed've": "she'd've",
270
+ "she'dve": "she'd've",
271
+ "she's": "she's",
272
+ "shouldve": "should've",
273
+ "shouldnt": "shouldn't",
274
+ "shouldnt've": "shouldn't've",
275
+ "shouldn'tve": "shouldn't've",
276
+ "somebody'd": "somebodyd",
277
+ "somebodyd've": "somebody'd've",
278
+ "somebody'dve": "somebody'd've",
279
+ "somebodyll": "somebody'll",
280
+ "somebodys": "somebody's",
281
+ "someoned": "someone'd",
282
+ "someoned've": "someone'd've",
283
+ "someone'dve": "someone'd've",
284
+ "someonell": "someone'll",
285
+ "someones": "someone's",
286
+ "somethingd": "something'd",
287
+ "somethingd've": "something'd've",
288
+ "something'dve": "something'd've",
289
+ "somethingll": "something'll",
290
+ "thats": "that's",
291
+ "thered": "there'd",
292
+ "thered've": "there'd've",
293
+ "there'dve": "there'd've",
294
+ "therere": "there're",
295
+ "theres": "there's",
296
+ "theyd": "they'd",
297
+ "theyd've": "they'd've",
298
+ "they'dve": "they'd've",
299
+ "theyll": "they'll",
300
+ "theyre": "they're",
301
+ "theyve": "they've",
302
+ "twas": "'twas",
303
+ "wasnt": "wasn't",
304
+ "wed've": "we'd've",
305
+ "we'dve": "we'd've",
306
+ "weve": "we've",
307
+ "werent": "weren't",
308
+ "whatll": "what'll",
309
+ "whatre": "what're",
310
+ "whats": "what's",
311
+ "whatve": "what've",
312
+ "whens": "when's",
313
+ "whered": "where'd",
314
+ "wheres": "where's",
315
+ "whereve": "where've",
316
+ "whod": "who'd",
317
+ "whod've": "who'd've",
318
+ "who'dve": "who'd've",
319
+ "wholl": "who'll",
320
+ "whos": "who's",
321
+ "whove": "who've",
322
+ "whyll": "why'll",
323
+ "whyre": "why're",
324
+ "whys": "why's",
325
+ "wont": "won't",
326
+ "wouldve": "would've",
327
+ "wouldnt": "wouldn't",
328
+ "wouldnt've": "wouldn't've",
329
+ "wouldn'tve": "wouldn't've",
330
+ "yall": "y'all",
331
+ "yall'll": "y'all'll",
332
+ "y'allll": "y'all'll",
333
+ "yall'd've": "y'all'd've",
334
+ "y'alld've": "y'all'd've",
335
+ "y'all'dve": "y'all'd've",
336
+ "youd": "you'd",
337
+ "youd've": "you'd've",
338
+ "you'dve": "you'd've",
339
+ "youll": "you'll",
340
+ "youre": "you're",
341
+ "youve": "you've",
342
+ }
343
+ self.manualMap = {
344
+ "none": "0",
345
+ "zero": "0",
346
+ "one": "1",
347
+ "two": "2",
348
+ "three": "3",
349
+ "four": "4",
350
+ "five": "5",
351
+ "six": "6",
352
+ "seven": "7",
353
+ "eight": "8",
354
+ "nine": "9",
355
+ "ten": "10",
356
+ }
357
+ self.articles = ["a", "an", "the"]
358
+
359
+ self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
360
+ self.commaStrip = re.compile(r"(\d)(\,)(\d)")
361
+ self.punct = [
362
+ ";",
363
+ r"/",
364
+ "[",
365
+ "]",
366
+ '"',
367
+ "{",
368
+ "}",
369
+ "(",
370
+ ")",
371
+ "=",
372
+ "+",
373
+ "\\",
374
+ "_",
375
+ "-",
376
+ ">",
377
+ "<",
378
+ "@",
379
+ "`",
380
+ ",",
381
+ "?",
382
+ "!",
383
+ ]
384
+
385
+ def evaluate(self, quesIds=None):
386
+ if quesIds == None:
387
+ quesIds = [quesId for quesId in self.params["question_id"]]
388
+ gts = {}
389
+ res = {}
390
+ for quesId in quesIds:
391
+ gts[quesId] = self.vqa.qa[quesId]
392
+ res[quesId] = self.vqaRes.qa[quesId]
393
+
394
+ # =================================================
395
+ # Compute accuracy
396
+ # =================================================
397
+ accQA = []
398
+ accQuesType = {}
399
+ accAnsType = {}
400
+ print("computing accuracy")
401
+ step = 0
402
+ for quesId in quesIds:
403
+ for ansDic in gts[quesId]["answers"]:
404
+ ansDic["answer"] = ansDic["answer"].replace("\n", " ")
405
+ ansDic["answer"] = ansDic["answer"].replace("\t", " ")
406
+ ansDic["answer"] = ansDic["answer"].strip()
407
+ resAns = res[quesId]["answer"]
408
+ resAns = resAns.replace("\n", " ")
409
+ resAns = resAns.replace("\t", " ")
410
+ resAns = resAns.strip()
411
+ gtAcc = []
412
+ gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
413
+
414
+ if len(set(gtAnswers)) > 1:
415
+ for ansDic in gts[quesId]["answers"]:
416
+ ansDic["answer"] = self.processPunctuation(ansDic["answer"])
417
+ ansDic["answer"] = self.processDigitArticle(ansDic["answer"])
418
+ resAns = self.processPunctuation(resAns)
419
+ resAns = self.processDigitArticle(resAns)
420
+
421
+ for gtAnsDatum in gts[quesId]["answers"]:
422
+ otherGTAns = [
423
+ item for item in gts[quesId]["answers"] if item != gtAnsDatum
424
+ ]
425
+ matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
426
+ acc = min(1, float(len(matchingAns)) / 3)
427
+ gtAcc.append(acc)
428
+ quesType = gts[quesId]["question_type"]
429
+ ansType = gts[quesId]["answer_type"]
430
+ avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
431
+ accQA.append(avgGTAcc)
432
+ if quesType not in accQuesType:
433
+ accQuesType[quesType] = []
434
+ accQuesType[quesType].append(avgGTAcc)
435
+ if ansType not in accAnsType:
436
+ accAnsType[ansType] = []
437
+ accAnsType[ansType].append(avgGTAcc)
438
+ self.setEvalQA(quesId, avgGTAcc)
439
+ self.setEvalQuesType(quesId, quesType, avgGTAcc)
440
+ self.setEvalAnsType(quesId, ansType, avgGTAcc)
441
+ if step % 100 == 0:
442
+ self.updateProgress(step / float(len(quesIds)))
443
+ step = step + 1
444
+
445
+ self.setAccuracy(accQA, accQuesType, accAnsType)
446
+ print("Done computing accuracy")
447
+
448
+ def processPunctuation(self, inText):
449
+ outText = inText
450
+ for p in self.punct:
451
+ if (p + " " in inText or " " + p in inText) or (
452
+ re.search(self.commaStrip, inText) != None
453
+ ):
454
+ outText = outText.replace(p, "")
455
+ else:
456
+ outText = outText.replace(p, " ")
457
+ outText = self.periodStrip.sub("", outText, re.UNICODE)
458
+ return outText
459
+
460
+ def processDigitArticle(self, inText):
461
+ outText = []
462
+ tempText = inText.lower().split()
463
+ for word in tempText:
464
+ word = self.manualMap.setdefault(word, word)
465
+ if word not in self.articles:
466
+ outText.append(word)
467
+ else:
468
+ pass
469
+ for wordId, word in enumerate(outText):
470
+ if word in self.contractions:
471
+ outText[wordId] = self.contractions[word]
472
+ outText = " ".join(outText)
473
+ return outText
474
+
475
+ def setAccuracy(self, accQA, accQuesType, accAnsType):
476
+ self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
477
+ self.accuracy["perQuestionType"] = {
478
+ quesType: round(
479
+ 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
480
+ self.n,
481
+ )
482
+ for quesType in accQuesType
483
+ }
484
+ self.accuracy["perAnswerType"] = {
485
+ ansType: round(
486
+ 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
487
+ )
488
+ for ansType in accAnsType
489
+ }
490
+
491
+ def setEvalQA(self, quesId, acc):
492
+ self.evalQA[quesId] = round(100 * acc, self.n)
493
+
494
+ def setEvalQuesType(self, quesId, quesType, acc):
495
+ if quesType not in self.evalQuesType:
496
+ self.evalQuesType[quesType] = {}
497
+ self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
498
+
499
+ def setEvalAnsType(self, quesId, ansType, acc):
500
+ if ansType not in self.evalAnsType:
501
+ self.evalAnsType[ansType] = {}
502
+ self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
503
+
504
+ def updateProgress(self, progress):
505
+ barLength = 20
506
+ status = ""
507
+ if isinstance(progress, int):
508
+ progress = float(progress)
509
+ if not isinstance(progress, float):
510
+ progress = 0
511
+ status = "error: progress var must be float\r\n"
512
+ if progress < 0:
513
+ progress = 0
514
+ status = "Halt...\r\n"
515
+ if progress >= 1:
516
+ progress = 1
517
+ status = "Done...\r\n"
518
+ block = int(round(barLength * progress))
519
+ text = "\rFinshed Percent: [{0}] {1}% {2}".format(
520
+ "#" * block + "-" * (barLength - block), int(progress * 100), status
521
+ )
522
+ sys.stdout.write(text)
523
+ sys.stdout.flush()
524
+
525
+
526
+ def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_path, vqa_dataset):
527
+ """Compute the VQA accuracy metric.
528
+
529
+ Args:
530
+ result_json_path (str): path to the model predictions json file
+ question_json_path (str): path to the VQA questions json file
531
+ annotation_json_path (str): path to the VQA annotations json file
+ vqa_dataset (str): name of the VQA dataset being evaluated
532
+
533
+ Returns:
534
+ float: VQA accuracy
535
+ """
536
+ # coding: utf-8
537
+ # dataDir = data_dir
538
+
539
+ # set up file names and paths
540
+ # versionType = 'v2_' # this should be '' when using VQA v2.0 dataset
541
+ # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
542
+ # taskType = 'OpenEnded'
543
+ # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
544
+ # dataType = 'mscoco'
545
+ # dataSubType = 'train2014'
546
+ # annFile = '%s/%s%s_%s_annotations.json' % (
547
+ # dataDir, versionType, dataType, dataSubType)
548
+ # quesFile = '%s/%s%s_%s_%s_questions.json' % (
549
+ # dataDir, versionType, taskType, dataType, dataSubType)
550
+ # imgDir = '%s/%s/%s/' % (dataDir, dataType, dataSubType)
551
+ # resultType = res_file_name
552
+ # fileTypes = ['results', 'accuracy',
553
+ # 'evalQA', 'evalQuesType', 'evalAnsType']
554
+
555
+ # An example result json file has been provided in './Results' folder.
556
+
557
+ # [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/%s%s_%s_%s_%s_%s.json' % (dataDir, versionType, taskType, dataType, dataSubType,
558
+ # resultType, fileType) for fileType in fileTypes]
559
+
560
+ # create vqa object and vqaRes object
561
+ vqa = VQA(annotation_json_path, question_json_path)
562
+ vqaRes = vqa.loadRes(result_json_path, question_json_path)
563
+
564
+ # create vqaEval object by taking vqa and vqaRes
565
+ # n is precision of accuracy (number of places after decimal), default is 2
566
+ vqaEval = VQAEval(vqa, vqaRes, n=2)
567
+
568
+ # evaluate results
569
+ """
570
+ If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
571
+ By default it uses all the question ids in annotation file
572
+ """
573
+ vqaEval.evaluate()
574
+
575
+ return vqaEval.accuracy["overall"]
576
+
577
+
578
+ def postprocess_vqa_generation(predictions):
579
+ return re.split("Question|Answer", predictions, 1)[0]
580
+
581
+
582
+ def compute_gqa_accuracy(results):
583
+ acc = []
584
+ vqa_tool = VQAEval()
585
+
586
+ for res in results:
587
+ gt_ans = res["answers"]
588
+ pred = res["answer"]
589
+ pred = vqa_tool.processPunctuation(pred)
590
+ pred = vqa_tool.processDigitArticle(pred)
591
+ vqa_acc = 1 if pred == gt_ans else 0
592
+ acc.append(vqa_acc)
593
+ accuracy = sum(acc) / len(acc)
594
+ return accuracy
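Note: the scoring rule inside VQAEval.evaluate gives a prediction `min(1, matches / 3)` credit against each annotator, where `matches` counts how many of the *other* annotators gave the same answer, and then averages. A stripped-down sketch of just that consensus rule, omitting the punctuation and digit/article normalization that VQAEval applies first:

```python
def vqa_consensus_accuracy(pred, gt_answers):
    scores = []
    for i in range(len(gt_answers)):
        others = gt_answers[:i] + gt_answers[i + 1:]
        matches = sum(1 for a in others if a == pred)
        scores.append(min(1.0, matches / 3))
    return sum(scores) / len(scores)

answers = ["red"] * 8 + ["blue"] * 2
print(vqa_consensus_accuracy("red", answers))   # 1.0
print(vqa_consensus_accuracy("blue", answers))  # ~0.6
```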
multimodal/build/lib/open_flamingo/src/__init__.py ADDED
File without changes
multimodal/build/lib/open_flamingo/src/attention.py ADDED
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import init
5
+
6
+
7
+
8
+ class SEAttention(nn.Module):
9
+
10
+ def __init__(self, channel=512,reduction=16):
11
+ super().__init__()
12
+ self.fc = nn.Sequential(
13
+ nn.Linear(channel, channel // reduction, bias=False),
14
+ nn.GELU(),
15
+ nn.Linear(channel // reduction, channel, bias=False),
16
+ nn.GELU(),
17
+ nn.Linear(channel, 1, bias=False),
18
+ nn.Sigmoid()
19
+ )
20
+
21
+
22
+ def init_weights(self):
23
+ for m in self.modules():
24
+ if isinstance(m, nn.Conv2d):
25
+ init.kaiming_normal_(m.weight, mode='fan_out')
26
+ if m.bias is not None:
27
+ init.constant_(m.bias, 0)
28
+ elif isinstance(m, nn.BatchNorm2d):
29
+ init.constant_(m.weight, 1)
30
+ init.constant_(m.bias, 0)
31
+ elif isinstance(m, nn.Linear):
32
+ init.normal_(m.weight, std=0.001)
33
+ if m.bias is not None:
34
+ init.constant_(m.bias, 0)
35
+
36
+ def forward(self, x):
37
+ x = self.fc(x)
38
+ return x
39
+
40
+
41
+ if __name__ == '__main__':
42
+ # nn.Linear acts on the last dimension, so the demo input must be channel-last: (batch, tokens, channel)
+ input = torch.randn(50, 49, 512)
43
+ se = SEAttention(channel=512, reduction=8)
44
+ output = se(input)
45
+ print(output.shape)  # torch.Size([50, 49, 1])
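Note: SEAttention here is just a per-token scalar gate: the Linear stack maps a channel-last feature to a single sigmoid score. How that score is consumed is not shown in this diff, so the weighting below is only an assumed usage; the layer stack itself mirrors the module above.

```python
import torch
from torch import nn

# Same layer stack as SEAttention.fc with channel=512, reduction=16.
fc = nn.Sequential(
    nn.Linear(512, 512 // 16, bias=False),
    nn.GELU(),
    nn.Linear(512 // 16, 512, bias=False),
    nn.GELU(),
    nn.Linear(512, 1, bias=False),
    nn.Sigmoid(),
)

tokens = torch.randn(2, 49, 512)  # (batch, tokens, channel)
gate = fc(tokens)                 # (2, 49, 1), values in (0, 1)
weighted = tokens * gate          # one possible (assumed) way to apply the gate
print(gate.shape, weighted.shape)
```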
multimodal/build/lib/open_flamingo/src/factory.py ADDED
@@ -0,0 +1,269 @@
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ import open_clip
3
+ import torch
4
+
5
+ from .flamingo import Flamingo
6
+ from .flamingo_lm import FlamingoLMMixin
7
+ from .utils import extend_instance
8
+ import logging
9
+ import random
10
+ import time
11
+
12
+ def create_model_and_transforms(
13
+ clip_vision_encoder_path: str,
14
+ clip_vision_encoder_pretrained: str,
15
+ lang_encoder_path: str,
16
+ tokenizer_path: str,
17
+ use_local_files: bool = False,
18
+ decoder_layers_attr_name: str = None,
19
+ location_token_num: int = 1000,
20
+ checkpoint_activations: bool = False,
21
+ freeze_vision_encoder: bool = False,
22
+ lora: bool = False,
23
+ lora_r: int = 16,
24
+ fix_ffn: bool = False,
25
+ add_visual_token: bool = False,
26
+ add_box: bool = False,
27
+ add_pe: bool = False,
28
+ add_relation: bool = False,
29
+ use_format_v2: bool = False,
30
+ use_sam: str = None,
31
+ enhance_data: bool = False,
32
+ roi_align: bool = False,
33
+ roi_output_size: int = 4,
34
+ apply_mask: bool = False,
35
+ **flamingo_kwargs,
36
+ ):
37
+ """
38
+ Initialize a Flamingo model from a pretrained vision encoder and language encoder.
39
+ Appends special tokens to the tokenizer and freezes backbones.
40
+
41
+ Args:
42
+ clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
43
+ clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
44
+ lang_encoder_path (str): path to pretrained language encoder
45
+ tokenizer_path (str): path to pretrained tokenizer
46
+ cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
47
+ use_local_files (bool, optional): whether to use local files. Defaults to False.
48
+ decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
49
+ Returns:
50
+ Flamingo: Flamingo model from pretrained vision and language encoders
51
+ Image processor: Pipeline to preprocess input images
52
+ Tokenizer: A tokenizer for the language model
53
+ """
54
+ if use_sam is None:
55
+ no_success = True
56
+ while no_success:
57
+ try:
58
+ vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
59
+ clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
60
+ )
61
+ no_success = False
62
+ except Exception:  # keep retrying on transient errors, but let KeyboardInterrupt through
63
+ logging.info("retry creating vision_encoder")
64
+ time.sleep(random.random() * 5)
65
+
66
+ # set the vision encoder to output the visual features
67
+ vision_encoder.visual.output_tokens = True
68
+ # delete text encoder part
69
+ del vision_encoder.transformer
70
+ del vision_encoder.text_projection
71
+ del vision_encoder.token_embedding
72
+ del vision_encoder.ln_final
73
+ del vision_encoder.positional_embedding
74
+ del vision_encoder.logit_scale
75
+ vision_encoder.visual.proj = None
76
+ vision_encoder.visual.ln_post = torch.nn.Identity()
77
+ else:
78
+ from segment_anything import SamPredictor, sam_model_registry
79
+ assert use_sam == "vit_l"
80
+ sam = sam_model_registry[use_sam](checkpoint="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_256x256.pth")
81
+ del sam.prompt_encoder
82
+ del sam.mask_decoder
83
+ sam.image_encoder.neck = torch.nn.Identity()
84
+ vision_encoder = sam.image_encoder
85
+ from open_clip.transform import image_transform
86
+ image_processor = image_transform(
87
+ 256,
88
+ is_train=False,
89
+ mean=(0.48145466, 0.4578275, 0.40821073),
90
+ std=(0.26862954, 0.26130258, 0.27577711),
91
+ )
92
+
93
+ text_tokenizer = AutoTokenizer.from_pretrained(
94
+ tokenizer_path, local_files_only=use_local_files
95
+ )
96
+ # add Flamingo special tokens to the tokenizer
97
+ additional_special_tokens = ["<|#image#|>", "<|#endofimage#|>"]
98
+ if add_visual_token:
99
+ additional_special_tokens += ["<|#visual#|>", "<|#object#|>"]
100
+ if add_box:
101
+ additional_special_tokens += ["<|#box#|>", "<|#endofobject#|>", "<|#attr#|>", "<|#endofattr#|>"]
102
+ if use_format_v2:
103
+ additional_special_tokens += ["<|#previsual#|>", "<|#prebox#|>"]
104
+ if enhance_data:
105
+ additional_special_tokens += ["<|#NOTHING#|>"]
106
+ text_tokenizer.add_special_tokens(
107
+ {"additional_special_tokens": additional_special_tokens}
108
+ )
109
+ if text_tokenizer.pad_token is None:
110
+ # Issue: GPT models don't have a pad token, which we use to
111
+ # modify labels for the loss.
112
+ text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
113
+
114
+ lang_encoder = AutoModelForCausalLM.from_pretrained(
115
+ lang_encoder_path, local_files_only=use_local_files
116
+ )
117
+ extend_instance(lang_encoder, FlamingoLMMixin)
118
+
119
+ if decoder_layers_attr_name is None:
120
+ decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
121
+ lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
122
+ lang_encoder.resize_token_embeddings(len(text_tokenizer))
123
+ lang_encoder_name = lang_encoder.__class__.__name__.lower()
124
+ if checkpoint_activations:
125
+ from fairscale.nn.checkpoint import checkpoint_wrapper
126
+ if use_sam is None:
127
+ for i in range(len(vision_encoder.visual.transformer.resblocks)):
128
+ vision_encoder.visual.transformer.resblocks[i] = checkpoint_wrapper(
129
+ vision_encoder.visual.transformer.resblocks[i],
130
+ offload_to_cpu=False,
131
+ )
132
+ else:
133
+ for i in range(len(vision_encoder.blocks)):
134
+ vision_encoder.blocks[i] = checkpoint_wrapper(
135
+ vision_encoder.blocks[i],
136
+ offload_to_cpu=False,
137
+ )
138
+ if "opt" in lang_encoder_name:
139
+ for i in range(len(lang_encoder.model.decoder.layers)):
140
+ lang_encoder.model.decoder.layers[i] = checkpoint_wrapper(
141
+ lang_encoder.model.decoder.layers[i],
142
+ offload_to_cpu=False,
143
+ )
144
+ elif "codegen" in lang_encoder_name:
145
+ for i in range(len(lang_encoder.transformer.h)):
146
+ lang_encoder.transformer.h[i] = checkpoint_wrapper(
147
+ lang_encoder.transformer.h[i],
148
+ offload_to_cpu=False,
149
+ )
150
+ elif "llama" in lang_encoder_name:
151
+ for i in range(len(lang_encoder.model.layers)):
152
+ lang_encoder.model.layers[i] = checkpoint_wrapper(
153
+ lang_encoder.model.layers[i],
154
+ offload_to_cpu=False,
155
+ )
156
+         elif "gptneo" in lang_encoder_name:  # matches GPT-NeoX-style class names; their blocks live under gpt_neox.layers
157
+ for i in range(len(lang_encoder.gpt_neox.layers)):
158
+ lang_encoder.gpt_neox.layers[i] = checkpoint_wrapper(
159
+ lang_encoder.gpt_neox.layers[i],
160
+ offload_to_cpu=False,
161
+ )
162
+ else:
163
+ raise ValueError(f"unknown model {lang_encoder_name}")
164
+ if use_sam is None:
165
+ vis_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"]
166
+ image_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["image_size"]
167
+ patch_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["patch_size"]
168
+ else:
169
+ # SAM config
170
+ vis_dim = 1024
171
+ image_size = 256
172
+ patch_size = 16
173
+ assert image_size % patch_size == 0
174
+ vis_embed_size = (image_size // patch_size) ** 2
175
+
176
+ if lora:
177
+ from peft import LoraConfig, TaskType
178
+ from peft import get_peft_model
179
+ if "codegen" in lang_encoder_name:
180
+ lang_target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"]
181
+ elif "opt" in lang_encoder_name:
182
+ lang_target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
183
+ elif "llama" in lang_encoder_name:
184
+ lang_target_modules = ["k_proj", "v_proj", "q_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
185
+ else:
186
+ raise NotImplementedError
187
+ lang_peft_config = LoraConfig(
188
+ task_type="CAUSAL_LM",
189
+ r=16, lora_alpha=16,
190
+ target_modules=lang_target_modules,
191
+ lora_dropout=0.05, bias="none",
192
+ )
193
+ lang_encoder = get_peft_model(lang_encoder, lang_peft_config)
194
+ lang_encoder.print_trainable_parameters()
195
+
196
+ if fix_ffn:
197
+ if "opt" in lang_encoder_name:
198
+ for i in range(len(lang_encoder.model.decoder.layers)):
199
+ lang_encoder.model.decoder.layers[i].requires_grad_(False)
200
+ lang_encoder.model.decoder.layers[i].self_attn.requires_grad_(True)
201
+ else:
202
+ raise NotImplementedError
203
+
204
+ lang_dim = int(lang_encoder.config.hidden_size) if not lora else int(lang_encoder.base_model.model.config.hidden_size)
205
+ if hasattr(lang_encoder.config, "word_embed_proj_dim"):
206
+ hidden_state_dim = lang_encoder.config.word_embed_proj_dim
207
+ else:
208
+ hidden_state_dim = lang_encoder.config.hidden_size
209
+ model = Flamingo(
210
+ vision_encoder=vision_encoder,
211
+ lang_encoder=lang_encoder,
212
+ eoc_token_id=text_tokenizer.encode(text_tokenizer.eos_token)[-1],
213
+ media_token_id=text_tokenizer.encode("<|#image#|>")[-1],
214
+ image_end_token_id=text_tokenizer.encode("<|#endofimage#|>")[-1],
215
+ visual_token_id=text_tokenizer.encode("<|#visual#|>")[-1] if add_visual_token else None,
216
+ previsual_token_id=text_tokenizer.encode("<|#previsual#|>")[-1] if add_visual_token else None,
217
+ box_token_id=text_tokenizer.encode("<|#box#|>")[-1] if add_box else None,
218
+ prebox_token_id=text_tokenizer.encode("<|#prebox#|>")[-1] if add_box else None,
219
+ nothing_token_id=text_tokenizer.encode("<|#NOTHING#|>")[-1] if enhance_data else None,
220
+ endofobject_token_id=text_tokenizer.encode("<|#endofobject#|>")[-1],
221
+ vis_dim=vis_dim,
222
+ vis_embed_size=vis_embed_size,
223
+ lang_dim=lang_dim,
224
+ image_size=image_size,
225
+ patch_size=patch_size,
226
+ hidden_state_dim=hidden_state_dim,
227
+ add_visual_token=add_visual_token,
228
+ add_pe=add_pe,
229
+ add_relation=add_relation,
230
+ use_format_v2=use_format_v2,
231
+ roi_align=roi_align,
232
+ roi_output_size=roi_output_size,
233
+ apply_mask=apply_mask,
234
+ **flamingo_kwargs,
235
+ )
236
+
237
+ if freeze_vision_encoder:
238
+ print("freeze vision encoder")
239
+ model.vision_encoder.requires_grad_(False)
240
+
241
+ print(
242
+ f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
243
+ )
244
+
245
+ return model, image_processor, text_tokenizer, vis_embed_size
246
+
247
+
248
+ def _infer_decoder_layers_attr_name(model):
249
+ for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
250
+ if k.lower() in model.__class__.__name__.lower():
251
+ return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
252
+
253
+ raise ValueError(
254
+         "Could not infer the decoder layers attribute name for this model. Please supply decoder_layers_attr_name (the dotted path to the nn.ModuleList of transformer blocks) manually."
255
+ )
256
+
257
+
258
+ __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
259
+ "opt": "model.decoder.layers",
260
+ # "gptneo": "transformer.h",
261
+ "gptj": "transformer.h",
262
+ "gpt-j": "transformer.h",
263
+ "pythia": "gpt_neox.layers",
264
+ "gptneox": "gpt_neox.layers",
265
+ "llama": "model.layers",
266
+ "llamaforcausallm": "model.layers",
267
+ "gpt2": "transformer.h",
268
+ "codegen": "transformer.h",
269
+ }
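
The factory resolves where a given language model keeps its transformer blocks by substring-matching the class name against the table above; the returned dotted path is later walked with `getattr_recursive` from `src/utils.py`. A minimal, self-contained sketch of that lookup (re-implementing the logic rather than importing this package; the stub class is hypothetical):

```python
# Minimal sketch of the class-name -> decoder-layers-path lookup used in factory.py.
# The table is a subset of __KNOWN_DECODER_LAYERS_ATTR_NAMES; OPTForCausalLMStub is a stand-in.
KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptneox": "gpt_neox.layers",
    "llama": "model.layers",
}

def infer_decoder_layers_attr_name(model) -> str:
    name = model.__class__.__name__.lower()
    for key, attr_path in KNOWN_DECODER_LAYERS_ATTR_NAMES.items():
        if key in name:
            return attr_path
    raise ValueError("Please supply decoder_layers_attr_name manually.")

class OPTForCausalLMStub:
    pass

print(infer_decoder_layers_attr_name(OPTForCausalLMStub()))  # -> "model.decoder.layers"
```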
multimodal/build/lib/open_flamingo/src/flamingo.py ADDED
@@ -0,0 +1,637 @@
1
+ import torch
2
+ import torchvision
3
+ from einops import rearrange
4
+ from torch import nn
5
+ from yolox.models.yolo_head import YOLOXHead
6
+ from yolox.utils.boxes import xyxy2cxcywh, cxcywh2xyxy
7
+ from yolox.utils.demo_utils import nms
8
+ # import matplotlib.pyplot as plt
9
+ # import seaborn as sns
10
+ import numpy as np
11
+ import logging
12
+ from open_flamingo.src.gcn import GCN
13
+ from transformers import LogitsProcessorList
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format='%(asctime)s %(message)s',
17
+ datefmt='%m/%d %I:%M:%S',
18
+ )
19
+
20
+
21
+ # class PositionEncodingModule(nn.Module):
22
+ # def __init__(self, dim, pos_dim=128):
23
+ # super().__init__()
24
+ # self.encode = nn.Sequential(
25
+ # nn.Linear(5, pos_dim // 2),
26
+ # nn.BatchNorm1d(pos_dim // 2),
27
+ # nn.GELU(),
28
+ # nn.Linear(pos_dim // 2, pos_dim),
29
+ # nn.BatchNorm1d(pos_dim),
30
+ # nn.GELU(),
31
+ # )
32
+ # self.merge = nn.Sequential(
33
+ # nn.Linear(dim + pos_dim, dim),
34
+ # nn.BatchNorm1d(dim),
35
+ # nn.GELU(),
36
+ # )
37
+
38
+ # def forward(self, x, box):
39
+ # box = self.encode(box)
40
+ # x = torch.cat([x, box], dim=-1)
41
+ # x = self.merge(x)
42
+ # return x
43
+
44
+
45
+ # class PositionEncodingModule(nn.Module):
46
+ # def __init__(self, dim):
47
+ # super().__init__()
48
+ # self.encode = nn.Sequential(
49
+ # nn.Linear(5, dim),
50
+ # nn.GELU(),
51
+ # )
52
+
53
+ # def forward(self, x, box):
54
+ # box = self.encode(box)
55
+ # x = x + box
56
+ # return x
57
+
58
+
59
+ # class PositionEncodingModule2(nn.Module):
60
+ # def __init__(self, dim):
61
+ # super().__init__()
62
+ # self.encode = nn.Sequential(
63
+ # nn.Linear(5 + dim, dim),
64
+ # nn.ELU(),
65
+ # )
66
+
67
+ # def forward(self, x, box):
68
+ # x = torch.cat([x, box], dim=-1)
69
+ # x = self.encode(x)
70
+ # return x
71
+
72
+
73
+ # class RelationHead(nn.Module):
74
+ # def __init__(self, dim):
75
+ # super().__init__()
76
+ # self.encode = nn.Sequential(
77
+ # nn.LayerNorm(dim),
78
+ # nn.Linear(dim, 128),
79
+ # nn.ELU(),
80
+ # )
81
+ # self.classifier = nn.Linear(256, 51)
82
+
83
+ # def forward(self, x1, x2):
84
+ # x1 = self.encode(x1)
85
+ # x2 = self.encode(x2)
86
+ # x = torch.cat([x1, x2], dim=-1)
87
+ # x = self.classifier(x)
88
+ # return x
89
+
90
+
91
+ class Flamingo(nn.Module):
92
+ def __init__(
93
+ self,
94
+ vision_encoder: nn.Module,
95
+ lang_encoder: nn.Module,
96
+ eoc_token_id: int,
97
+ media_token_id: int,
98
+ image_end_token_id: int,
99
+ visual_token_id: int,
100
+ previsual_token_id: int,
101
+ box_token_id: int,
102
+ prebox_token_id: int,
103
+ nothing_token_id: int,
104
+ endofobject_token_id: int,
105
+ vis_dim: int,
106
+ vis_embed_size: int,
107
+ lang_dim: int,
108
+ hidden_state_dim: int,
109
+ image_size: int,
110
+ patch_size: int,
111
+ use_media_placement_augmentation: bool = False,
112
+ add_visual_token: bool = False,
113
+ add_pe: bool = False,
114
+ add_relation: bool = False,
115
+ use_format_v2: bool = False,
116
+ roi_align: bool = False,
117
+ roi_output_size: int = 4,
118
+ apply_mask: bool = False,
119
+ ):
120
+ """
121
+ Args:
122
+ vision_encoder (nn.Module): HF CLIPModel
123
+ lang_encoder (nn.Module): HF causal language model
124
+ eoc_token_id (int): Token id for eos token
125
+ media_token_id (int): Token id for <|#image#|>
126
+ vis_dim (int): Dimension of the visual features.
127
+ Visual features are projected to match this shape along the last dimension.
128
+ cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
129
+ use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False.
130
+ """
131
+ super().__init__()
132
+ self.image_end_token_id = image_end_token_id
133
+ self.eoc_token_id = eoc_token_id
134
+ self.media_token_id = media_token_id
135
+ self.use_media_placement_augmentation = use_media_placement_augmentation
136
+ self.vis_dim = vis_dim
137
+ self.lang_dim = lang_dim
138
+ # inner_dim = self.lang_dim * 4
139
+ # self.vis_proj = nn.Sequential(
140
+ # nn.LayerNorm(self.vis_dim),
141
+ # nn.Linear(self.vis_dim, inner_dim, bias=False),
142
+ # nn.GELU(),
143
+ # nn.Linear(inner_dim, self.lang_dim, bias=False),
144
+ # )
145
+ self.vis_proj = nn.Linear(self.vis_dim, self.lang_dim)
146
+ self.vision_encoder = vision_encoder
147
+ self.num_positions = vis_embed_size
148
+ self.lang_encoder = lang_encoder
149
+ self.lang_encoder.init_flamingo(
150
+ media_token_id=media_token_id,
151
+ use_media_placement_augmentation=self.use_media_placement_augmentation,
152
+ )
153
+ first_layer = self.lang_encoder._get_decoder_layers()[0]
154
+ first_layer.add_visual_token = add_visual_token
155
+ first_layer.visual_token_id = visual_token_id
156
+ first_layer.media_token_id = media_token_id
157
+ first_layer.box_token_id = box_token_id
158
+ # first_layer.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
159
+ # assert not (add_pe and add_relation)
160
+ # self.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
161
+ # first_layer.pos_enc = self.pos_enc
162
+ self.box_token_id = box_token_id
163
+ self.prebox_token_id = prebox_token_id
164
+ self.media_token_id = media_token_id
165
+ self.visual_token_id = visual_token_id
166
+ self.previsual_token_id = previsual_token_id
167
+ self.hidden_state_dim = hidden_state_dim
168
+ self.image_size = image_size
169
+ self.patch_size = patch_size
170
+ self.patch_num = self.image_size // self.patch_size
171
+ self.detection_head = YOLOXHead(
172
+ num_classes=1,
173
+ strides=[patch_size],
174
+ in_channels=[self.hidden_state_dim + self.lang_dim],
175
+ )
176
+ self.use_format_v2 = use_format_v2
177
+ self.nothing_token_id = nothing_token_id
178
+ self.roi_align = roi_align
179
+ self.roi_output_size = roi_output_size if roi_align else None
180
+ self.apply_mask = apply_mask
181
+ self.endofobject_token_id = endofobject_token_id
182
+
183
+
184
+ def _get_detection_batch(
185
+ self,
186
+ visual_token_id,
187
+ previsual_token_id,
188
+ input_ids: torch.Tensor,
189
+ hidden_states: torch.Tensor,
190
+ added_bbox_list,
191
+ box_num = 100,
192
+ ):
193
+ select_mask = torch.logical_or(input_ids == visual_token_id, input_ids == previsual_token_id)
194
+ visual_token_position = select_mask.nonzero()
195
+ visual_token_hidden_states = hidden_states[select_mask]
196
+ prev_batch_idx = -1
197
+ media_idx = []
198
+ cnt = 0
199
+ assert len(visual_token_hidden_states) == len(visual_token_position)
200
+ if len(added_bbox_list) != len(visual_token_position):
201
+ msg = f"ERROR: {len(added_bbox_list)}:{len(visual_token_position)}\n{added_bbox_list}\n{visual_token_position}"
202
+ logging.info(msg)
203
+ alpha = 0.0
204
+ else:
205
+ alpha = 1.0
206
+ visual_batches = []
207
+ previsual_batches = []
208
+ for (batch_idx, idx), visual_token_hidden_state, bbox in zip(
209
+ visual_token_position, visual_token_hidden_states, added_bbox_list,
210
+ ):
211
+             # clone first: xyxy2cxcywh below modifies its argument in place, so we must not
+             # mutate the caller's bbox tensor
+             bbox = bbox.clone()
214
+ batch_idx = batch_idx.item()
215
+ idx = idx.item()
216
+ if batch_idx != prev_batch_idx:
217
+ prev_batch_idx = batch_idx
218
+ this_input_ids = input_ids[batch_idx]
219
+ cnt += len(media_idx)
220
+ media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
221
+ for i in range(len(media_idx)):
222
+                 if i == len(media_idx) - 1 or (media_idx[i] < idx < media_idx[i+1]):
223
+ break
224
+ image_index = cnt + i
225
+ size = int(self.image_embedding[image_index].shape[0] ** 0.5)
226
+ image_embedding = self.image_embedding[image_index]
227
+ # inplace xyxy2cxcywh
228
+ # print(bbox)
229
+ # TODO: CHECK self.image_size. Is it 224?
230
+ bbox = xyxy2cxcywh(bbox) * self.image_size
231
+ # print(bbox)
232
+ concat_image_visual_embedding = torch.cat([image_embedding, visual_token_hidden_state.unsqueeze(0).repeat(image_embedding.shape[0], 1)], dim=-1).reshape(size, size, -1)
233
+ label = torch.cat([torch.zeros(bbox.shape[0], 1, device=bbox.device), bbox], dim=-1)
234
+ label = torch.cat([label, torch.zeros(box_num - label.shape[0], label.shape[1], device=label.device)], dim=0)
235
+ if input_ids[batch_idx, idx] == previsual_token_id:
236
+ previsual_batches.append([concat_image_visual_embedding, label])
237
+ elif input_ids[batch_idx, idx] == visual_token_id:
238
+ visual_batches.append([concat_image_visual_embedding, label])
239
+ else:
240
+ logging.info(f"WARNING... NOT visual nor previsual. it is {input_ids[batch_idx, idx]}")
241
+ return visual_batches, previsual_batches, alpha, alpha
242
+
243
+ def get_detection_losses(
244
+ self,
245
+ input_ids: torch.Tensor,
246
+ hidden_states: torch.Tensor,
247
+ added_bbox_list,
248
+ box_num = 100,
249
+ ):
250
+ visual_token_batches, previsual_token_batches, alpha1, alpha2 = self._get_detection_batch(
251
+ visual_token_id=self.visual_token_id,
252
+ previsual_token_id=self.previsual_token_id,
253
+ input_ids=input_ids,
254
+ hidden_states=hidden_states,
255
+ added_bbox_list=added_bbox_list,
256
+ box_num=box_num,
257
+ )
258
+ loss_dict = []
259
+ for batches, alpha in zip([visual_token_batches, previsual_token_batches], [alpha1, alpha2]):
260
+ # x: [B, C, H, W]
261
+ if len(batches) != 0:
262
+ x = torch.cat([batch[0].unsqueeze(0) for batch in batches], dim=0).permute(0,3,1,2)
263
+ labels = torch.cat([batch[1].unsqueeze(0) for batch in batches], dim=0)
264
+ else:
265
+ x = None
266
+ labels = None
267
+ if x is not None:
268
+ losses = self.detection_head(xin=[x], labels=labels)
269
+ loss, loss_iou, loss_obj, loss_cls, loss_l1, _ = losses
270
+ else:
271
+ loss = torch.tensor(0.0).cuda()
272
+ loss_iou = loss
273
+ loss_obj = loss
274
+ loss_cls = loss
275
+ loss_l1 = loss
276
+
277
+ loss_dict.append(dict(
278
+ loss=loss * alpha,
279
+ loss_iou=loss_iou * alpha,
280
+ loss_obj=loss_obj * alpha,
281
+ loss_cls=loss_cls * alpha,
282
+ loss_l1=loss_l1 * alpha,
283
+ ))
284
+ ret_loss = {}
285
+ for key in loss_dict[0].keys():
286
+ ret_loss[key] = 0.0
287
+ for d in loss_dict:
288
+ ret_loss[key] += d[key]
289
+ return ret_loss, loss_dict
290
+
291
+ def get_detection_result(
292
+ self,
293
+ input_ids: torch.Tensor,
294
+ hidden_states: torch.Tensor,
295
+ nms_thr: float = 0.45,
296
+ score_thr: float = 0.01,
297
+ debug_id: int = 0,
298
+ debug_mode: bool = False,
299
+ ):
300
+         assert len(input_ids) == 1, "only batch size 1 is supported for now"
301
+ # assert len(self.image_embedding) == 1, "only one image is supported yet"
302
+ # assert (input_ids[..., -1] == self.visual_token_id).all(), "the last token should be visual token"
303
+ visual_token_hidden_state = hidden_states[..., -1, :]
304
+ boxes_list = []
305
+ scores_list = []
306
+ for image_embedding in self.image_embedding:
307
+ size = int(image_embedding.shape[0] ** 0.5)
308
+ x = torch.cat([image_embedding, visual_token_hidden_state.repeat(image_embedding.shape[0], 1)], dim=-1).reshape(size, size, -1).unsqueeze(0).permute(0,3,1,2)
309
+ with torch.no_grad():
310
+ outputs = self.detection_head(xin=[x], labels=None)
311
+ boxes = outputs[0,:,:4].cpu().numpy()
312
+ scores = outputs[0,:,4].cpu().numpy()
313
+ scores_mask = scores > score_thr
314
+ boxes = boxes[scores_mask]
315
+ boxes = cxcywh2xyxy(boxes)
316
+ scores = scores[scores_mask]
317
+ keep = nms(boxes, scores, nms_thr=nms_thr)
318
+ boxes = boxes[keep]
319
+ scores = scores[keep]
320
+ if debug_mode:
321
+ obj_heatmap = outputs[0,:, -2].reshape(size, size).cpu().numpy()
322
+ import matplotlib.pyplot as plt
323
+ import seaborn as sns
324
+ plt.figure()
325
+ sns_plot = sns.heatmap(obj_heatmap)
326
+ plt.savefig(f"heatmap_{debug_id}.jpg")
327
+ debug_id += 1
328
+ boxes_list.append(boxes)
329
+ scores_list.append(scores)
330
+ if len(boxes_list) == 1:
331
+ boxes_list = boxes_list[0]
332
+ scores_list = scores_list[0]
333
+ return boxes_list, scores_list
334
+
335
+ def _condition_attention(self, loc_list = None):
336
+ for i in range(len(self.lang_encoder.gpt_neox.layers)):
337
+ self.lang_encoder.gpt_neox.layers[i].decoder_layer.attention.loc_list = loc_list
338
+
339
+ def forward(
340
+ self,
341
+ vision_x: torch.Tensor,
342
+ lang_x: torch.Tensor,
343
+ attention_mask: torch.Tensor = None,
344
+ labels: torch.Tensor = None,
345
+ use_cached_vision_x: bool = False,
346
+ clear_conditioned_layers: bool = True,
347
+ past_key_values=None,
348
+ use_cache: bool = False,
349
+ image_nums=None,
350
+ image_start_index_list=None,
351
+ added_bbox_list=None,
352
+ add_box: bool = False,
353
+ relations=None,
354
+ debug_mode: bool = False,
355
+ ):
356
+ """
357
+ Forward pass of Flamingo.
358
+
359
+ Args:
360
+ vision_x (torch.Tensor): Vision input
361
+ shape (B, T_img, F, C, H, W) with F=1
362
+ lang_x (torch.Tensor): Language input ids
363
+ shape (B, T_txt)
364
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
365
+ labels (torch.Tensor, optional): Labels. Defaults to None.
366
+ clear_conditioned_layers: if True, clear the conditioned layers
367
+                 once the forward pass is completed. Set this to False if the
368
+ same set of images will be reused in another subsequent
369
+ forward pass.
370
+ past_key_values: pre-computed values to pass to language model.
371
+ See past_key_values documentation in Hugging Face
372
+ CausalLM models.
373
+ use_cache: whether to use cached key values. See use_cache
374
+ documentation in Hugging Face CausalLM models.
375
+ """
376
+ self.valid = True
377
+ self.lang_encoder.loc_list = None
378
+ if use_cached_vision_x:
379
+ # Case: use cached; vision_x should be cached and other
380
+ # vision-related inputs should not be provided.
381
+ assert (
382
+ vision_x is None
383
+ ), "Expect vision_x to be None when use_cached_vision_x is True."
384
+ assert self.lang_encoder.is_conditioned()
385
+ else:
386
+ # Case: do not use caching (i.e. this is a standard forward pass);
387
+ self._encode_vision_x(
388
+ vision_x=vision_x,
389
+ image_nums=image_nums,
390
+ image_start_index_list=image_start_index_list,
391
+ added_bbox_list=added_bbox_list if add_box else None,
392
+ input_ids=lang_x,
393
+ relations=relations,
394
+ )
395
+ if self.apply_mask:
396
+ if self.roi_align:
397
+ attend_length = 1 + self.roi_output_size ** 2
398
+ else:
399
+ attend_length = 2
400
+ prebox_loc = (lang_x == self.prebox_token_id).nonzero()
401
+ loc_list = []
402
+ for (x, y) in prebox_loc:
403
+ x = x.item()
404
+ y = y.item()
405
+ for yy in range(y+1, lang_x.shape[1]):
406
+ if lang_x[x, yy] == self.endofobject_token_id:
407
+ # [batch_idx, [previsual:prebox], [object:endofobject-1]]
408
+ loc_list.append([x, [y-attend_length+1, y], [y+1, yy-1]])
409
+ self._condition_attention(loc_list=loc_list)
410
+ else:
411
+ self._condition_attention(None)
412
+
413
+ output = self.lang_encoder(
414
+ input_ids=lang_x,
415
+ attention_mask=attention_mask,
416
+ labels=labels,
417
+ past_key_values=past_key_values,
418
+ use_cache=use_cache,
419
+ output_hidden_states=True,
420
+ )
421
+         if vision_x is None:
+             # text-only step: run a zero-weighted dummy forward through the vision tower so its
+             # parameters stay in the autograd graph (avoids unused-parameter errors under DDP)
+             output['loss'][0] += 0.0 * self.vis_proj(self.vision_encoder.visual(torch.randn(1, 3, 224, 224, device=lang_x.device, dtype=output['loss'].dtype))[1]).mean()
423
+
424
+ hidden_states = output["hidden_states"][-1]
425
+ if self.training and added_bbox_list is not None:
426
+ detection_losses, loss_dict = self.get_detection_losses(
427
+ input_ids=lang_x,
428
+ hidden_states=hidden_states,
429
+ added_bbox_list=added_bbox_list,
430
+ )
431
+ output["detection_losses"] = detection_losses
432
+ output["loss_dict"] = loss_dict
433
+ elif labels is None:
434
+ boxes, scores = self.get_detection_result(
435
+ input_ids=lang_x,
436
+ hidden_states=hidden_states,
437
+ debug_id=self.debug_id if hasattr(self, "debug_id") else None,
438
+ debug_mode=debug_mode,
439
+ )
440
+ output["boxes"] = boxes
441
+ output["scores"] = scores
442
+
443
+ if clear_conditioned_layers:
444
+ self.lang_encoder.clear_conditioned_layers()
445
+ self._condition_attention(None)
446
+ return output
447
+
448
+ def generate(
449
+ self,
450
+ vision_x: torch.Tensor,
451
+ lang_x: torch.Tensor,
452
+ attention_mask: torch.Tensor = None,
453
+ added_bbox_list=None,
454
+ num_beams=1,
455
+ max_new_tokens=None,
456
+ temperature=1.0,
457
+ top_k=0,
458
+ top_p=1.0,
459
+ no_repeat_ngram_size=0,
460
+ prefix_allowed_tokens_fn=None,
461
+ length_penalty=1.0,
462
+ num_return_sequences=1,
463
+ do_sample=False,
464
+ early_stopping=False,
465
+ bad_words_ids=None,
466
+ force_words_ids=None,
467
+ image_start_index_list=None,
468
+ image_nums=None,
469
+ min_length=None,
470
+ return_dict_in_generate=False,
471
+ output_hidden_states=False,
472
+ output_scores=False,
473
+ logits_processor_list=None,
474
+ eos_token_id=None,
475
+ ):
476
+ """
477
+ Generate text conditioned on vision and language inputs.
478
+
479
+ Args:
480
+ vision_x (torch.Tensor): Vision input
481
+ shape (B, T_img, F, C, H, W)
482
+ images in the same chunk are collated along T_img, and frames are collated along F
483
+ currently only F=1 is supported (single-frame videos)
484
+ lang_x (torch.Tensor): Language input
485
+ shape (B, T_txt)
486
+ max_length (int, optional): Maximum length of the output. Defaults to None.
487
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
488
+ num_beams (int, optional): Number of beams. Defaults to 1.
489
+ max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
490
+ temperature (float, optional): Temperature. Defaults to 1.0.
491
+ top_k (int, optional): Top k. Defaults to 0.
492
+ top_p (float, optional): Top p. Defaults to 1.0.
493
+ no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
494
+ length_penalty (float, optional): Length penalty. Defaults to 1.0.
495
+ num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
496
+ do_sample (bool, optional): Do sample. Defaults to False.
497
+ early_stopping (bool, optional): Early stopping. Defaults to False.
498
+ Returns:
499
+ torch.Tensor: lang_x with generated tokens appended to it
500
+ """
501
+ if num_beams > 1:
502
+ vision_x = vision_x.repeat_interleave(num_beams, dim=0)
503
+ image_start_index_list = torch.tensor(image_start_index_list).repeat_interleave(num_beams, dim=0).tolist()
504
+ image_nums = torch.tensor(image_nums).repeat_interleave(num_beams, dim=0).tolist()
505
+ if added_bbox_list is not None and len(added_bbox_list) != 0:
506
+ added_bbox_list = added_bbox_list * num_beams
507
+
508
+ self._encode_vision_x(vision_x=vision_x, image_nums=image_nums, image_start_index_list=image_start_index_list, num_beams=num_beams, added_bbox_list=added_bbox_list, input_ids=lang_x.repeat_interleave(num_beams, dim=0))
509
+
510
+ if logits_processor_list is not None:
511
+ assert isinstance(logits_processor_list, list)
512
+ logits_processor_list = LogitsProcessorList(logits_processor_list)
513
+ output = self.lang_encoder.generate(
514
+ input_ids=lang_x,
515
+ attention_mask=attention_mask,
516
+ eos_token_id=(self.eoc_token_id) if eos_token_id is None else eos_token_id,
517
+ num_beams=num_beams,
518
+ max_new_tokens=max_new_tokens,
519
+ min_length=min_length,
520
+ length_penalty=length_penalty,
521
+ logits_processor=logits_processor_list,
522
+ return_dict_in_generate=return_dict_in_generate,
523
+ output_scores=output_scores,
524
+ )
525
+ self.lang_encoder.clear_conditioned_layers()
526
+ return output
527
+
528
+ def _get_data_list_and_visual_tokens(
529
+ self,
530
+ all_box_list,
531
+ box_token_id,
532
+ prebox_token_id,
533
+ input_ids,
534
+ vision_x,
535
+ nothing_embedding = None,
536
+ ):
537
+ box_locations = (torch.logical_or(input_ids == box_token_id, input_ids == prebox_token_id)).nonzero()
538
+ prev_batch_idx = -1
539
+ media_idx = []
540
+ cnt = 0
541
+ data_list = []
542
+ visual_tokens = []
543
+ if len(all_box_list) != len(box_locations):
544
+ logging.info(f"WARNING. len(all_box_list) != len(box_locations) {len(all_box_list)} vs {len(box_locations)}")
545
+ self.valid = False
546
+ for III, (batch_idx, idx) in enumerate(box_locations):
547
+ batch_idx = batch_idx.item()
548
+ idx = idx.item()
549
+ if batch_idx != prev_batch_idx:
550
+ prev_batch_idx = batch_idx
551
+ this_input_ids = input_ids[batch_idx]
552
+ cnt += len(media_idx)
553
+ media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
554
+ for i in range(len(media_idx)):
555
+             if i == len(media_idx) - 1 or (media_idx[i] < idx < media_idx[i+1]):
556
+ break
557
+ image_index = cnt + i
558
+ size = int(vision_x[image_index].shape[0] ** 0.5)
559
+ image_feature = vision_x[image_index].reshape(size, size, -1)
560
+             try:
+                 raw_xyxy = all_box_list[III]
+             except IndexError:
+                 # fewer boxes than box/prebox tokens; fall back to the last box so the pass can continue
+                 logging.info("out of scope for all_box_list")
+                 raw_xyxy = all_box_list[-1]
565
+ region_xyxy = np.array(raw_xyxy) * size
566
+ x1, y1, x2, y2 = region_xyxy.astype(int).clip(0, size-1).tolist()
567
+ x2 = max(x1, x2)
568
+ y2 = max(y1, y2)
569
+ if x1 + y1 + x2 + y2 == 0.0 and nothing_embedding is not None:
570
+ visual_token = nothing_embedding
571
+ else:
572
+ if self.roi_align:
573
+ visual_token = torchvision.ops.roi_align(
574
+ image_feature.permute(2, 0, 1).unsqueeze(0),
575
+ [torch.tensor(region_xyxy.astype(np.float32)).unsqueeze(0).cuda()],
576
+ output_size=self.roi_output_size,
577
+ spatial_scale=1.0,
578
+ )
579
+ visual_token = visual_token.squeeze(0).flatten(1).permute(1, 0)
580
+ else:
581
+ visual_token = image_feature[y1:y2+1, x1:x2+1].reshape(-1, image_feature.shape[-1]).mean(0)
582
+ box = torch.tensor([0] + raw_xyxy, device=visual_token.device, dtype=visual_token.dtype)
583
+ data_list.append([visual_token, box, batch_idx, idx, i])
584
+ visual_tokens.append(visual_token)
585
+ return data_list, visual_tokens
586
+
587
+ def _encode_vision_x(self, vision_x: torch.Tensor, image_nums=None, image_start_index_list=None, added_bbox_list=None, num_beams=None, input_ids=None, relations=None):
588
+ """
589
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
590
+ Args:
591
+ vision_x (torch.Tensor): Vision input
592
+ shape (B, T_img, F, C, H, W)
593
+ Images in the same chunk are collated along T_img, and frames are collated along F
594
+ Currently only F=1 is supported (single-frame videos)
595
+
596
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
597
+ """
598
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
599
+ b, T, F = vision_x.shape[:3]
600
+ assert F == 1, "Only single frame supported"
601
+
602
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
603
+ if hasattr(self.vision_encoder, "visual"):
604
+ vision_x = self.vision_encoder.visual(vision_x)[1]
605
+ else:
606
+ vision_x = self.vision_encoder(vision_x).flatten(2).permute(0, 2, 1)
607
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
608
+
609
+ # print(vision_x[0,0,0])
610
+ # # DEBUG HERE
611
+ # if torch.distributed.get_rank() == 0:
612
+ # import pdb; pdb.set_trace()
613
+ # else:
614
+ # torch.distributed.barrier()
615
+ vision_x = vision_x.mean(2)
616
+ # vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
617
+ # vision_x = self.vis_proj(vision_x) + self.vis_position_embedding(self.vis_position_ids).unsqueeze(0)
618
+ vision_x = self.vis_proj(vision_x).squeeze(1)
619
+ self.image_embedding = vision_x
620
+
621
+ data_list = None
622
+ visual_tokens = None
623
+ if added_bbox_list is not None and input_ids is not None:
624
+             all_box_list = added_bbox_list[0].tolist()
+             for boxes in added_bbox_list[1:]:  # renamed from `list` to avoid shadowing the builtin
+                 all_box_list.extend(boxes.tolist())
627
+ data_list, visual_tokens = self._get_data_list_and_visual_tokens(
628
+ all_box_list=all_box_list,
629
+ box_token_id=self.box_token_id,
630
+ prebox_token_id=self.prebox_token_id,
631
+ input_ids=input_ids,
632
+ vision_x=vision_x,
633
+ nothing_embedding=self.lang_encoder.gpt_neox.embed_in(torch.tensor(self.nothing_token_id).to(self.lang_encoder.gpt_neox.embed_in.weight.device)) if self.nothing_token_id is not None else None,
634
+ )
635
+
636
+ first_layer = self.lang_encoder._get_decoder_layers()[0]
637
+ first_layer.condition_vis_x(vision_x, image_nums, image_start_index_list, num_beams=num_beams, visual_tokens=visual_tokens, data_list=[[d[2], d[3]] for d in data_list] if data_list is not None else data_list)
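
One step worth seeing in isolation is how `_get_data_list_and_visual_tokens` turns a normalized box into a single visual token by mean-pooling the patch features it covers (the ROI-align branch is omitted). A hedged, standalone sketch with invented grid size and feature dimension:

```python
import torch

# Hedged sketch of the mean-pooling branch of _get_data_list_and_visual_tokens:
# a normalized xyxy box selects a rectangle of patch features, which are averaged
# into a single visual token. The 16x16 grid and 1024-dim features are illustrative.
def pool_box_feature(image_feature: torch.Tensor, box_xyxy) -> torch.Tensor:
    """image_feature: (size, size, dim) patch grid; box_xyxy: normalized [x1, y1, x2, y2]."""
    size = image_feature.shape[0]
    x1, y1, x2, y2 = (torch.tensor(box_xyxy) * size).int().clamp(0, size - 1).tolist()
    x2, y2 = max(x1, x2), max(y1, y2)                       # guard against degenerate boxes
    return image_feature[y1:y2 + 1, x1:x2 + 1].reshape(-1, image_feature.shape[-1]).mean(0)

patch_grid = torch.randn(16, 16, 1024)                      # 16x16 patches, 1024-dim features
token = pool_box_feature(patch_grid, [0.25, 0.25, 0.75, 0.75])
print(token.shape)                                          # torch.Size([1024])
```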
multimodal/build/lib/open_flamingo/src/flamingo_lm.py ADDED
@@ -0,0 +1,173 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+ from .helpers import GatedCrossAttentionBlock
7
+ from .utils import getattr_recursive, setattr_recursive
8
+
9
+
10
+ class FlamingoLayer(nn.Module):
11
+ def __init__(self, decoder_layer):
12
+ super().__init__()
13
+ self.decoder_layer = decoder_layer
14
+ self.vis_x = None
15
+ self.image_nums = None
16
+ self.image_start_index_list = None
17
+ self.media_locations = None
18
+ self.add_visual_token = False
19
+ self.input_ids = None
20
+
21
+ def is_conditioned(self) -> bool:
22
+ """Check whether the layer is conditioned."""
23
+ return self.vis_x is not None
24
+
25
+ # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
26
+ def condition_vis_x(self, vis_x, image_nums=None, image_start_index_list=None, num_beams=None, visual_tokens=None, data_list=None):
27
+ self.vis_x = vis_x
28
+ self.image_nums = image_nums
29
+ self.image_start_index_list = image_start_index_list
30
+ self.num_beams = num_beams
31
+ self.visual_tokens = visual_tokens
32
+ self.data_list = data_list
33
+ self.input_ids = None
34
+
35
+
36
+ def condition_media_locations(self, media_locations):
37
+ self.media_locations = media_locations
38
+
39
+ def condition_attend_previous(self, attend_previous):
40
+ self.attend_previous = attend_previous
41
+
42
+ def forward(
43
+ self,
44
+ hidden_states, # alignment with hugging face name
45
+ attention_mask=None,
46
+ **decoder_layer_kwargs,
47
+ ):
48
+ if self.media_locations is None:
49
+ raise ValueError("media_locations must be conditioned before forward pass")
50
+
51
+ if self.vis_x is not None:
52
+ if self.training:
53
+ single_length = self.vis_x.shape[-2]
54
+ image_nums = self.image_nums
55
+ image_start_index_list = self.image_start_index_list
56
+ image_nums = [0] + np.cumsum(image_nums).tolist()
57
+ for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
58
+ for index in start_indices:
59
+ if image_num_begin < image_num_end:
60
+ hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
61
+ image_num_begin += 1
62
+
63
+ if self.visual_tokens is not None and len(self.visual_tokens) != 0:
64
+ for i, (x, y) in enumerate(self.data_list):
65
+ if len(self.visual_tokens[i].shape) > 1:
66
+ # print(self.visual_tokens[i].shape[0], "embedding")
67
+ hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i]
68
+ else:
69
+ # print(self.visual_tokens[i].shape[0], "embedding")
70
+ hidden_states[x, y] = self.visual_tokens[i]
71
+
72
+ elif not self.training:
73
+ if (
74
+ ("past_key_value" in decoder_layer_kwargs and decoder_layer_kwargs["past_key_value"] is None) or
75
+ ("layer_past" in decoder_layer_kwargs and decoder_layer_kwargs["layer_past"] is None)
76
+ ):
77
+ single_length = self.vis_x.shape[-2]
78
+ image_nums = self.image_nums
79
+ image_start_index_list = self.image_start_index_list
80
+ image_nums = [0] + np.cumsum(image_nums).tolist()
81
+ for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
82
+ for index in start_indices:
83
+ if image_num_begin < image_num_end:
84
+ hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
85
+ image_num_begin += 1
86
+ if self.visual_tokens is not None and len(self.visual_tokens) != 0:
87
+ for i, (x, y) in enumerate(self.data_list):
88
+ # import pdb; pdb.set_trace()
89
+ # print(x, y, self.visual_tokens[i].shape)
90
+ if len(self.visual_tokens[i].shape) > 1:
91
+ # print(self.visual_tokens[i].shape[0], "embedding")
92
+ hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i]
93
+ else:
94
+ # print(self.visual_tokens[i].shape[0], "embedding")
95
+ hidden_states[x, y] = self.visual_tokens[i]
96
+ hidden_states = self.decoder_layer(
97
+ hidden_states, attention_mask=attention_mask, **decoder_layer_kwargs
98
+ )
99
+ return hidden_states
100
+
101
+
102
+ class FlamingoLMMixin(nn.Module):
103
+ """
104
+ Mixin to add cross-attention layers to a language model.
105
+ """
106
+
107
+ def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
108
+ self.decoder_layers_attr_name = decoder_layers_attr_name
109
+
110
+ def _get_decoder_layers(self):
111
+ return getattr_recursive(self, self.decoder_layers_attr_name)
112
+
113
+ def _set_decoder_layers(self, value):
114
+ setattr_recursive(self, self.decoder_layers_attr_name, value)
115
+
116
+ def init_flamingo(
117
+ self,
118
+ media_token_id,
119
+ use_media_placement_augmentation,
120
+ ):
121
+ """
122
+ Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
123
+ """
124
+ self._set_decoder_layers(
125
+ nn.ModuleList(
126
+ [FlamingoLayer(decoder_layer) for decoder_layer in self._get_decoder_layers()]
127
+ )
128
+ )
129
+ self.media_token_id = media_token_id
130
+ self.use_media_placement_augmentation = use_media_placement_augmentation
131
+ self.initialized_flamingo = True
132
+
133
+ def forward(self, *input, **kwargs):
134
+ """Condition the Flamingo layers on the media locations before forward()"""
135
+ if not self.initialized_flamingo:
136
+ raise ValueError(
137
+ "Flamingo layers are not initialized. Please call `init_flamingo` first."
138
+ )
139
+
140
+ input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
141
+ media_locations = input_ids == self.media_token_id
142
+ attend_previous = (
143
+ (random.random() < 0.5) if self.use_media_placement_augmentation else True
144
+ )
145
+
146
+ if (
147
+ "gpt2" in self.__class__.__name__.lower()
148
+ or "codegen" in self.__class__.__name__.lower()
149
+ ):
150
+ for layer in self.transformer.h:
151
+ layer.condition_media_locations(media_locations)
152
+ layer.condition_attend_previous(attend_previous)
153
+ elif "gptneox" in self.__class__.__name__.lower():
154
+ for layer in self.gpt_neox.layers:
155
+ layer.condition_media_locations(media_locations)
156
+ layer.condition_attend_previous(attend_previous)
157
+ else:
158
+ for layer in self.get_decoder().layers:
159
+ layer.condition_media_locations(media_locations)
160
+ layer.condition_attend_previous(attend_previous)
161
+ return super().forward(
162
+ *input, **kwargs
163
+ ) # Call the other parent's forward method
164
+
165
+ def is_conditioned(self) -> bool:
166
+ """Check whether all decoder layers are already conditioned."""
167
+ return all(l.is_conditioned() for l in self._get_decoder_layers())
168
+
169
+ def clear_conditioned_layers(self):
170
+ for layer in self._get_decoder_layers():
171
+ layer.condition_vis_x(None)
172
+ layer.condition_media_locations(None)
173
+ layer.condition_attend_previous(None)
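
Since `FlamingoLayer` here does not add gated cross-attention but instead overwrites slots of the hidden states with cached image embeddings, a small self-contained sketch of that splicing loop may help (all shapes and indices below are invented):

```python
import torch

# Illustrative sketch of the splice performed in FlamingoLayer.forward: each cached image
# embedding (vis_x) overwrites `single_length` positions of the text hidden states, starting
# at the offsets recorded in image_start_index_list.
hidden_states = torch.zeros(2, 40, 8)            # (batch, seq_len, hidden_dim)
vis_x = torch.randn(3, 16, 8)                    # 3 images, 16 visual tokens each
image_nums = [2, 1]                              # images per batch element
image_start_index_list = [[1, 20], [5]]          # insertion offsets per batch element

single_length = vis_x.shape[-2]
cum = [0] + torch.cumsum(torch.tensor(image_nums), dim=0).tolist()
for i, (begin, end, starts) in enumerate(zip(cum[:-1], cum[1:], image_start_index_list)):
    for index in starts:
        if begin < end:
            hidden_states[i, index:index + single_length] = vis_x[begin]
            begin += 1

print(hidden_states[0, 1:17].abs().sum() > 0)    # image 0 was written into sample 0
```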
multimodal/build/lib/open_flamingo/src/gcn.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.parameter import Parameter
5
+ import math
6
+ from torch.autograd import Variable
7
+ from torchvision.ops import box_iou
8
+
9
+
10
+
11
+ class GraphConvolution(nn.Module):
12
+ """
13
+ Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
14
+ """
15
+
16
+ def __init__(self, in_features, out_features, bias=True, skip=True):
17
+ super(GraphConvolution, self).__init__()
18
+ self.skip = skip
19
+ self.in_features = in_features
20
+ self.out_features = out_features
21
+ self.weight = Parameter(torch.Tensor(in_features, out_features))
22
+ if bias:
23
+ self.bias = Parameter(torch.Tensor(out_features))
24
+ else:
25
+ self.register_parameter('bias', None)
26
+ self.reset_parameters()
27
+
28
+ def reset_parameters(self):
29
+ stdv = 1. / math.sqrt(self.weight.size(1))
30
+ self.weight.data.uniform_(-stdv, stdv)
31
+ if self.bias is not None:
32
+ self.bias.data.uniform_(-stdv, stdv)
33
+
34
+ def forward(self, input, adj):
35
+ # TODO make fc more efficient via "pack_padded_sequence"
36
+ # import ipdb; ipdb.set_trace()
37
+ support = torch.bmm(input, self.weight.unsqueeze(
38
+ 0).expand(input.shape[0], -1, -1))
39
+ output = torch.bmm(adj, support)
40
+ #output = SparseMM(adj)(support)
41
+ if self.bias is not None:
42
+ output += self.bias.unsqueeze(0).expand(input.shape[0], -1, -1)
43
+ if self.skip:
44
+ output += support
45
+
46
+ return output
47
+
48
+ def __repr__(self):
49
+ return self.__class__.__name__ + ' (' \
50
+ + str(self.in_features) + ' -> ' \
51
+ + str(self.out_features) + ')'
52
+
53
+
54
+ class GCN_sim(nn.Module):
55
+ def __init__(self, dim_in, dim_hidden, dim_out, dropout, num_layers):
56
+ super(GCN_sim, self).__init__()
57
+ assert num_layers >= 1
58
+ self.fc_k = nn.Linear(dim_in, dim_hidden)
59
+ self.fc_q = nn.Linear(dim_in, dim_hidden)
60
+
61
+ dim_hidden = dim_out if num_layers == 1 else dim_hidden
62
+ self.gcs = nn.ModuleList([
63
+ GraphConvolution(dim_in, dim_hidden)
64
+ ])
65
+
66
+ for i in range(num_layers - 1):
67
+ dim_tmp = dim_out if i == num_layers-2 else dim_hidden
68
+ self.gcs.append(GraphConvolution(dim_hidden, dim_tmp))
69
+
70
+ self.dropout = dropout
71
+
72
+ def construct_graph(self, x, length):
73
+ # TODO make fc more efficient via "pack_padded_sequence"
74
+ emb_k = self.fc_k(x)
75
+ emb_q = self.fc_q(x)
76
+
77
+ s = torch.bmm(emb_k, emb_q.transpose(1, 2))
78
+
79
+ s_mask = s.data.new(*s.size()).fill_(1).bool() # [B, T1, T2]
80
+ # Init similarity mask using lengths
81
+ for i, (l_1, l_2) in enumerate(zip(length, length)):
82
+ s_mask[i][:l_1, :l_2] = 0
83
+ s_mask = Variable(s_mask)
84
+ s.data.masked_fill_(s_mask.data, -float("inf"))
85
+
86
+ a_weight = F.softmax(s, dim=2) # [B, t1, t2]
87
+ # remove nan from softmax on -inf
88
+ a_weight.data.masked_fill_(a_weight.data != a_weight.data, 0)
89
+
90
+ return a_weight
91
+
92
+ def forward(self, x, length):
93
+ adj_sim = self.construct_graph(x, length)
94
+
95
+ for gc in self.gcs:
96
+ x = F.relu(gc(x, adj_sim))
97
+ x = F.dropout(x, self.dropout, training=self.training)
98
+
99
+ return x
100
+
101
+
102
+ class GCN(nn.Module):
103
+ def __init__(self, dim_in, dim_hidden, dim_out, dropout, mode, skip, num_layers, ST_n_next=None):
104
+ super(GCN, self).__init__()
105
+ assert len(mode) != 0
106
+ self.mode = mode
107
+ self.skip = skip
108
+
109
+ if "GCN_sim" in mode:
110
+ self.GCN_sim = GCN_sim(
111
+ dim_in, dim_hidden, dim_out, dropout, num_layers)
112
+
113
+ def forward(self, x, length):
114
+
115
+ out = []
116
+ if "GCN_sim" in self.mode:
117
+ out.append(self.GCN_sim(x, length))
118
+
119
+ out = sum(out)
120
+ if self.skip:
121
+ out += x
122
+
123
+ return out
124
+
125
+
126
+ if __name__ == '__main__':
127
+ model = GCN(512, 128, 512, 0.5, mode=[
128
+ "GCN_sim"], skip=True, num_layers=3, ST_n_next=3)
129
+ bs, T, N = 10, 5, 10
130
+ n_node = T*N
131
+
132
+ input = torch.rand(bs, n_node, 512)
133
+ length = torch.ones((bs))
134
+ length = length.type(torch.IntTensor)
135
+ bboxes = torch.rand((bs, 5, 10, 4))
136
+
137
+ output = model(input, length)
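
The adjacency used by `GCN_sim` is just a masked softmax over pairwise similarities; a tiny standalone sketch of that masking step follows (batch size, lengths, and dimensions are invented for illustration):

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the masked softmax in GCN_sim.construct_graph: similarities outside
# each element's valid length are pushed to -inf, softmaxed, and fully-masked rows are zeroed.
emb_k = torch.randn(2, 6, 8)
emb_q = torch.randn(2, 6, 8)
length = torch.tensor([6, 3])                    # valid nodes per batch element

s = torch.bmm(emb_k, emb_q.transpose(1, 2))      # (B, N, N) pairwise similarities
mask = torch.ones_like(s, dtype=torch.bool)
for i, l in enumerate(length.tolist()):
    mask[i, :l, :l] = False                      # keep only valid-vs-valid pairs
s = s.masked_fill(mask, float("-inf"))
adj = F.softmax(s, dim=2)
adj = torch.nan_to_num(adj, nan=0.0)             # rows that were fully masked become all zeros

print(adj[1, 3:].sum())                          # padded rows contribute nothing -> tensor(0.)
```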
multimodal/build/lib/open_flamingo/src/helpers.py ADDED
@@ -0,0 +1,263 @@
1
+ """
2
+ Taken from https://github.com/lucidrains/flamingo-pytorch
3
+ """
4
+
5
+ import torch
6
+ from einops import rearrange, repeat
7
+ from einops_exts import rearrange_many
8
+ from torch import einsum, nn
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def FeedForward(dim, mult=4):
16
+ inner_dim = int(dim * mult)
17
+ return nn.Sequential(
18
+ nn.LayerNorm(dim),
19
+ nn.Linear(dim, inner_dim, bias=False),
20
+ nn.GELU(),
21
+ nn.Linear(inner_dim, dim, bias=False),
22
+ )
23
+
24
+
25
+ class PerceiverAttention(nn.Module):
26
+ def __init__(self, *, dim, dim_head=64, heads=8):
27
+ super().__init__()
28
+ self.scale = dim_head**-0.5
29
+ self.heads = heads
30
+ inner_dim = dim_head * heads
31
+
32
+ self.norm_media = nn.LayerNorm(dim)
33
+ self.norm_latents = nn.LayerNorm(dim)
34
+
35
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
36
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
37
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
38
+
39
+ def forward(self, x, latents):
40
+ """
41
+ Args:
42
+ x (torch.Tensor): image features
43
+ shape (b, T, n1, D)
44
+ latent (torch.Tensor): latent features
45
+ shape (b, T, n2, D)
46
+ """
47
+ x = self.norm_media(x)
48
+ latents = self.norm_latents(latents)
49
+
50
+ h = self.heads
51
+
52
+ q = self.to_q(latents)
53
+ kv_input = torch.cat((x, latents), dim=-2)
54
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
55
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
56
+ q = q * self.scale
57
+
58
+ # attention
59
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
60
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
61
+ attn = sim.softmax(dim=-1)
62
+
63
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
64
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
65
+ return self.to_out(out)
66
+
67
+
68
+ class PerceiverResampler(nn.Module):
69
+ def __init__(
70
+ self,
71
+ *,
72
+ dim,
73
+ depth=6,
74
+ dim_head=64,
75
+ heads=8,
76
+ num_latents=64,
77
+ max_num_media=None,
78
+ max_num_frames=None,
79
+ ff_mult=4,
80
+ ):
81
+ super().__init__()
82
+ assert False, "Do not use PerceiverResampler"
83
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
84
+ self.frame_embs = (
85
+ nn.Parameter(torch.randn(max_num_frames, dim))
86
+ if exists(max_num_frames)
87
+ else None
88
+ )
89
+ self.media_time_embs = (
90
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
91
+ if exists(max_num_media)
92
+ else None
93
+ )
94
+
95
+ self.layers = nn.ModuleList([])
96
+ for _ in range(depth):
97
+ self.layers.append(
98
+ nn.ModuleList(
99
+ [
100
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
101
+ FeedForward(dim=dim, mult=ff_mult),
102
+ ]
103
+ )
104
+ )
105
+
106
+ self.norm = nn.LayerNorm(dim)
107
+
108
+ def forward(self, x):
109
+ """
110
+ Args:
111
+ x (torch.Tensor): image features
112
+ shape (b, T, F, v, D)
113
+ Returns:
114
+ shape (b, T, n, D) where n is self.num_latents
115
+ """
116
+ b, T, F, v = x.shape[:4]
117
+
118
+ # frame and media time embeddings
119
+ if exists(self.frame_embs):
120
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
121
+ x = x + frame_embs
122
+ x = rearrange(
123
+ x, "b T F v d -> b T (F v) d"
124
+ ) # flatten the frame and spatial dimensions
125
+ if exists(self.media_time_embs):
126
+ x = x + self.media_time_embs[:T]
127
+
128
+ # blocks
129
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
130
+ for attn, ff in self.layers:
131
+ latents = attn(x, latents) + latents
132
+ latents = ff(latents) + latents
133
+ return self.norm(latents)
134
+
135
+
136
+ # gated cross attention
137
+
138
+
139
+ class MaskedCrossAttention(nn.Module):
140
+ def __init__(
141
+ self,
142
+ *,
143
+ dim,
144
+ dim_visual,
145
+ dim_head=64,
146
+ heads=8,
147
+ only_attend_immediate_media=True,
148
+ ):
149
+ super().__init__()
150
+ self.scale = dim_head**-0.5
151
+ self.heads = heads
152
+ inner_dim = dim_head * heads
153
+
154
+ self.norm = nn.LayerNorm(dim)
155
+
156
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
157
+ self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
158
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
159
+
160
+ # whether for text to only attend to immediate preceding image, or all previous images
161
+ self.only_attend_immediate_media = only_attend_immediate_media
162
+
163
+ def forward(self, x, media, media_locations=None, attend_previous=True):
164
+ """
165
+ Args:
166
+ x (torch.Tensor): text features
167
+ shape (B, T_txt, D_txt)
168
+ media (torch.Tensor): image features
169
+ shape (B, T_img, n, D_img) where n is the dim of the latents
170
+ media_locations: boolean mask identifying the media tokens in x
171
+ shape (B, T_txt)
172
+ attend_previous: bool
173
+                 If False, ignore the immediately preceding image and only attend once a following image appears
+         """
+         assert attend_previous, "text must attend to the image that precedes it"
176
+
177
+ _, T_img, n = media.shape[:3]
178
+ h = self.heads
179
+
180
+ x = self.norm(x)
181
+
182
+ q = self.to_q(x)
183
+ media = rearrange(media, "b t n d -> b (t n) d")
184
+
185
+ k, v = self.to_kv(media).chunk(2, dim=-1)
186
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
187
+
188
+ q = q * self.scale
189
+
190
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
191
+
192
+ if exists(media_locations):
193
+ # at each boolean of True, increment the time counter (relative to media time)
194
+ text_time = media_locations.cumsum(dim=-1)
195
+ media_time = torch.arange(T_img, device=x.device) + 1
196
+
197
+ if not attend_previous:
198
+ text_time[~media_locations] += 1
199
+ # make sure max is still the number of images in the sequence
200
+ text_time[
201
+ text_time
202
+ > repeat(
203
+ torch.count_nonzero(media_locations, dim=1),
204
+ "b -> b i",
205
+ i=text_time.shape[1],
206
+ )
207
+ ] = 0
208
+
209
+ # text time must equal media time if only attending to most immediate image
210
+ # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
211
+ mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
212
+
213
+ text_to_media_mask = mask_op(
214
+ rearrange(text_time, "b i -> b 1 i 1"),
215
+ repeat(media_time, "j -> 1 1 1 (j n)", n=n),
216
+ )
217
+ sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
218
+
219
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
220
+ attn = sim.softmax(dim=-1)
221
+
222
+ if exists(media_locations) and self.only_attend_immediate_media:
223
+ # any text without a preceding media needs to have attention zeroed out
224
+ text_without_media_mask = text_time == 0
225
+ text_without_media_mask = rearrange(
226
+ text_without_media_mask, "b i -> b 1 i 1"
227
+ )
228
+ attn = attn.masked_fill(text_without_media_mask, 0.0)
229
+
230
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
231
+ out = rearrange(out, "b h n d -> b n (h d)")
232
+ return self.to_out(out)
233
+
234
+
235
+ class GatedCrossAttentionBlock(nn.Module):
236
+ def __init__(
237
+ self,
238
+ *,
239
+ dim,
240
+ dim_visual,
241
+ dim_head=64,
242
+ heads=8,
243
+ ff_mult=4,
244
+ only_attend_immediate_media=True,
245
+ ):
246
+ super().__init__()
247
+ self.attn = MaskedCrossAttention(
248
+ dim=dim,
249
+ dim_visual=dim_visual,
250
+ dim_head=dim_head,
251
+ heads=heads,
252
+ only_attend_immediate_media=only_attend_immediate_media,
253
+ )
254
+
255
+ def forward(
256
+ self,
257
+ x,
258
+ media,
259
+ media_locations=None,
260
+ attend_previous=True,
261
+ ):
262
+ x = self.attn(x, media, media_locations=media_locations, attend_previous=attend_previous) + x
263
+ return x
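
To make the media-time masking in `MaskedCrossAttention` concrete, here is a hedged, standalone sketch of how the text-to-media mask comes out for a short invented sequence with two images and three latents per image:

```python
import torch
from einops import rearrange, repeat

# Hedged sketch of the text-to-media mask built in MaskedCrossAttention.forward with
# only_attend_immediate_media=True: each text token attends only to the latents of the
# most recent image. The 6-token sequence below is invented.
media_locations = torch.tensor([[1, 0, 0, 1, 0, 0]])   # 1 marks an image token
T_img, n = 2, 3                                        # 2 images, 3 latents per image

text_time = media_locations.cumsum(dim=-1)             # image index each token belongs to
media_time = torch.arange(T_img) + 1                   # 1-indexed image ids
text_to_media_mask = torch.eq(
    rearrange(text_time, "b i -> b 1 i 1"),
    repeat(media_time, "j -> 1 1 1 (j n)", n=n),
)
print(text_to_media_mask[0, 0].int())
# rows 0-2 (tokens of image 1) attend to latents 0-2; rows 3-5 attend to latents 3-5
```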
multimodal/build/lib/open_flamingo/src/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ def extend_instance(obj, mixin):
2
+ """Apply mixins to a class instance after creation"""
3
+ base_cls = obj.__class__
4
+ base_cls_name = obj.__class__.__name__
5
+ obj.__class__ = type(
6
+ base_cls_name, (mixin, base_cls), {}
7
+ ) # mixin needs to go first for our forward() logic to work
8
+
9
+
10
+ def getattr_recursive(obj, att):
11
+ """
12
+ Return nested attribute of obj
13
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
14
+ """
15
+ if att == "":
16
+ return obj
17
+ i = att.find(".")
18
+ if i < 0:
19
+ return getattr(obj, att)
20
+ else:
21
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
22
+
23
+
24
+ def setattr_recursive(obj, att, val):
25
+ """
26
+ Set nested attribute of obj
27
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
28
+ """
29
+ if "." in att:
30
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
31
+ setattr(obj, att.split(".")[-1], val)
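
A quick usage illustration of these helpers, assuming the `open_flamingo` package from this repository is installed (the toy classes below are hypothetical):

```python
from open_flamingo.src.utils import extend_instance, getattr_recursive, setattr_recursive

class Greeter:
    def hello(self):
        return "hello"

class LoudMixin:
    def hello(self):
        # the mixin is placed first in the MRO by extend_instance, so super() reaches Greeter
        return super().hello().upper()

g = Greeter()
extend_instance(g, LoudMixin)
print(g.hello())                               # "HELLO"

class Node:
    pass

root = Node()
root.child = Node()
root.child.value = 3
print(getattr_recursive(root, "child.value"))  # 3
setattr_recursive(root, "child.value", 7)
print(root.child.value)                        # 7
```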
multimodal/build/lib/open_flamingo/train/__init__.py ADDED
@@ -0,0 +1 @@
1
+
multimodal/build/lib/open_flamingo/train/data2.py ADDED
@@ -0,0 +1,868 @@
1
+ import functools
2
+ import logging
3
+ import math
4
+ import random
5
+ import sys
6
+ from dataclasses import dataclass
7
+ from multiprocessing import Value
8
+ import time
9
+ import os
10
+ import numpy as np
11
+ import pickle as pkl
12
+ from open_flamingo.train.instruction_template import (
13
+ VG_RELATION_TEMPLATES,
14
+ PISC_TEMPLATES,
15
+ )
16
+
17
+ import torch
18
+ import webdataset as wds
19
+ from PIL import Image
20
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
21
+ from torch.utils.data.distributed import DistributedSampler
22
+ from webdataset.tariterators import (
23
+ base_plus_ext,
24
+ tar_file_expander,
25
+ url_opener,
26
+ valid_sample,
27
+ )
28
+
29
+ from groundingdino.demo.caption_grounder import caption_grounder
30
+ from groundingdino.demo.inference_on_laion import add_loc_to_text
31
+ from groundingdino.demo.inference_on_laion import nms_without_score
32
+ from groundingdino.demo.inference_on_laion import calculate_iou
33
+
34
+ Image.MAX_IMAGE_PIXELS = 1000000000
35
+ LAION2B_NUM_SAMPLE = 1500000000
36
+ VQAV2_TRAIN_NUM_SAMPLE = 1828467
37
+ VG_RELATION_BBOX_SIZE = 600
38
+
39
+ REL_LABELS = ['__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind', 'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for', 'from', 'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on', 'looking at', 'lying on', 'made of', 'mounted on', 'near', 'of', 'on', 'on back of', 'over', 'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on', 'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']
40
+
41
+ try:
42
+ import horovod.torch as hvd
43
+ except ImportError:
44
+ hvd = None
45
+
46
+ class ConcatDataset(IterableDataset):
47
+ def __init__(
48
+ self, dataset, max_length,
49
+ delimiter_id, pad_id=None, media_id=None, endofmedia_id=None,
50
+ image_embedding_size=-2, single=False, box_id=None, visual_id=None,
51
+ ):
52
+ self.dataset = dataset
53
+ self.max_length = max_length
54
+ self.delimiter_id = torch.ones(1,1).long() * delimiter_id
55
+ if pad_id is not None:
56
+ self.pad_id = int(pad_id)
57
+ if media_id is not None:
58
+ self.media_id = torch.ones(1,1).long() * int(media_id)
59
+ if endofmedia_id is not None:
60
+ self.endofmedia_id = torch.ones(1,1).long() * int(endofmedia_id)
61
+ if image_embedding_size > 0:
62
+ logging.info(f"image_embedding_size: {image_embedding_size}")
63
+ self.image_embedding_size = image_embedding_size + 2
64
+ self.single = single
65
+ self.box_id = box_id
66
+ self.visual_id = visual_id
67
+
68
+ def __iter__(self):
69
+ while True:
70
+ input_ids_list = []
71
+ attention_mask_list = []
72
+ image_list = []
73
+ image_start_index_list = []
74
+ added_bbox_list = []
75
+ relations_list = []
76
+ cnt = 0
77
+ while cnt < self.max_length:
78
+ sample = next(self.dataset)
79
+ if len(sample) >= 4:
80
+ image = sample[0].unsqueeze(0)
81
+ input_ids = sample[1]
82
+ attention_mask = sample[2]
83
+ added_bbox = sample[3]
84
+ image_list.append(image)
85
+ added_bbox_list.append(added_bbox)
86
+ if len(sample) == 5:
87
+ relations_list.append(sample[4])
88
+ else:
89
+ sample = sample[0]
90
+ input_ids = sample[0]
91
+ attention_mask = sample[1]
92
+ input_ids_list.append(input_ids)
93
+ attention_mask_list.append(attention_mask)
94
+ cnt += input_ids.shape[-1]
95
+ if self.single:
96
+ break
97
+ input_ids = torch.cat(input_ids_list, dim=-1)[0]
98
+ attention_mask = torch.cat(attention_mask_list, dim=-1)[0]
99
+ if not self.single:
100
+ input_ids = input_ids[:self.max_length]
101
+ attention_mask = attention_mask[:self.max_length]
102
+ # TODO: fix visual number not match
103
+ if len(image_list) != 0:
104
+ images = torch.cat(image_list, dim=0)
105
+ image_begin = (input_ids == self.media_id[0,0]).nonzero().view(-1)
106
+ image_end = (input_ids == self.endofmedia_id[0,0]).nonzero().view(-1)
107
+ if len(image_begin) != len(image_end):
108
+ assert len(image_begin) == len(image_end) + 1
109
+ input_ids[image_begin[-1]:] = self.pad_id
110
+ attention_mask[image_begin[-1]:] = 0
111
+ image_begin = image_begin[:-1]
112
+ eos_token_num = len((input_ids == self.delimiter_id[0,0]).nonzero().view(-1))
113
+ if eos_token_num != len(image_begin) + 1:
114
+ input_ids[image_begin[-1]:] = self.pad_id
115
+ attention_mask[image_begin[-1]:] = 0
116
+ image_begin = image_begin[:-1]
117
+ image_end = image_end[:-1]
118
+ images = images[:len(image_end)]
119
+ added_bbox_list = added_bbox_list[:len(image_end)]
120
+ relations_list = relations_list[:len(image_end)]
121
+ image_start_index_list = (image_begin + 1).tolist()
122
+ expand_list = added_bbox_list[0]
123
+ for x in added_bbox_list[1:]:
124
+ expand_list.extend(x)
125
+ yield images, len(images), image_start_index_list, input_ids, attention_mask, expand_list, relations_list
126
+ else:
127
+ yield input_ids, attention_mask
128
+
129
+
130
+ class SharedEpoch:
131
+ def __init__(self, epoch: int = 0):
132
+ self.shared_epoch = Value("i", epoch)
133
+
134
+ def set_value(self, epoch):
135
+ self.shared_epoch.value = epoch
136
+
137
+ def get_value(self):
138
+ return self.shared_epoch.value
139
+
140
+
141
+ @dataclass
142
+ class DataInfo:
143
+ dataloader: DataLoader
144
+ sampler: DistributedSampler = None
145
+ shared_epoch: SharedEpoch = None
146
+
147
+ def set_epoch(self, epoch):
148
+ if self.shared_epoch is not None:
149
+ self.shared_epoch.set_value(epoch)
150
+ if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
151
+ self.sampler.set_epoch(epoch)
152
+
153
+
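A minimal usage sketch of the two helpers above (names such as webloader and num_epochs are placeholders, not part of this commit): the epoch counter lives in a multiprocessing Value, so bumping it in the main process is visible to dataloader workers that hold the same SharedEpoch.

shared_epoch = SharedEpoch(epoch=0)
info = DataInfo(dataloader=webloader, shared_epoch=shared_epoch)
for epoch in range(num_epochs):
    info.set_epoch(epoch)  # workers read this via the shared Value before resampling shards
    for batch in info.dataloader:
        ...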
154
+ def filter_no_caption_or_no_image(sample):
155
+ return ("txt" in sample) and (
156
+ "png" in sample or "jpg" in sample or "jpeg" in sample
157
+ )
158
+
159
+
160
+ def log_and_continue(exn):
161
+ """Call in an exception handler to ignore any exception, issue a warning, and continue."""
162
+ if "ValueError" in repr(exn) or "KeyError" in repr(exn): # Avoid spamming logs with these
163
+ return True
164
+ logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
165
+ return True
166
+ # DEBUG
167
+ # log_and_continue = None
168
+ # DEBUG
169
+
170
+
171
+ def group_by_keys_nothrow(
172
+ data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
173
+ ):
174
+ """Return function over iterator that groups key, value pairs into samples.
175
+
176
+ :param keys: function that splits the key into key and extension (base_plus_ext)
177
+ :param lcase: convert suffixes to lower case (Default value = True)
178
+ """
179
+ current_sample = None
180
+ tar_idx = None
181
+ for filesample in data:
182
+ assert isinstance(filesample, dict)
183
+ current_tar_idx = filesample["__url__"].split("/")[-1].split(".")[0]
184
+ if current_tar_idx != tar_idx:
185
+ tar_idx = current_tar_idx
186
+ if "blip2_all_data_ground" in filesample["__url__"]:
187
+ relation_data_dir = os.path.join("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_all_data_relation", tar_idx)
188
+ missing_file = False
189
+ try:
190
+ data_info = pkl.load(open(os.path.join(relation_data_dir, "custom_data_info.pkl"), "rb"))
191
+ prediction = pkl.load(open(os.path.join(relation_data_dir, "custom_prediction.pkl"), "rb"))
192
+ idx_to_files = data_info["idx_to_files"]
193
+ ind_to_classes = data_info["ind_to_classes"]
194
+ ind_to_predicates = data_info["ind_to_predicates"]
195
+ files_to_idx = {x.split("#")[-1]: i for i, x in enumerate(idx_to_files)}
196
+ except Exception: # relation metadata for this shard is missing or unreadable
197
+ missing_file = True
198
+ fname, value = filesample["fname"], filesample["data"]
199
+ prefix, suffix = keys(fname)
200
+ if prefix is None:
201
+ continue
202
+ if lcase:
203
+ suffix = suffix.lower()
204
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
205
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
206
+ # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
207
+ if (
208
+ current_sample is None
209
+ or prefix != current_sample["__key__"]
210
+ or suffix in current_sample
211
+ ):
212
+ if valid_sample(current_sample):
213
+ yield current_sample
214
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
215
+ if "blip2_all_data_ground" in filesample["__url__"] and not missing_file:
216
+ try:
217
+ idx = files_to_idx[prefix]
218
+ prediction[idx]["bbox"] = [np.array(bbox)/VG_RELATION_BBOX_SIZE for bbox in prediction[idx]["bbox"]]
219
+ current_sample["relation_data"] = prediction[idx]
220
+ except Exception: # no relation entry for this sample; fall back to an empty dict
221
+ current_sample["relation_data"] = dict()
222
+ else:
223
+ current_sample["relation_data"] = dict()
224
+ if suffixes is None or suffix in suffixes:
225
+ current_sample[suffix] = value
226
+ if valid_sample(current_sample):
227
+ yield current_sample
228
+
229
+
230
+ def tarfile_to_samples_nothrow(src, handler=log_and_continue):
231
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
232
+ streams = url_opener(src, handler=handler)
233
+ files = tar_file_expander(streams, handler=handler)
234
+ samples = group_by_keys_nothrow(files, handler=handler)
235
+ return samples
236
+
237
+
238
+ def pytorch_worker_seed(increment=0):
239
+ """get dataloader worker seed from pytorch"""
240
+ worker_info = get_worker_info()
241
+ if worker_info is not None:
242
+ # favour using the seed already created for pytorch dataloader workers if it exists
243
+ seed = worker_info.seed
244
+ if increment:
245
+ # space out seed increments so they can't overlap across workers in different iterations
246
+ seed += increment * max(1, worker_info.num_workers)
247
+ return seed
248
+ # fallback to wds rank based seed
249
+ return wds.utils.pytorch_worker_seed()
250
+
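A small illustration of the seed spacing above (illustrative numbers only): with 4 workers and a base worker seed s, epoch increments 1, 2, 3 map to s + 4, s + 8, s + 12, which is what the max(1, num_workers) factor is for.

def example_epoch_seeds(base_seed, num_workers, num_epochs):
    # mirrors the increment logic in pytorch_worker_seed above
    return [base_seed + e * max(1, num_workers) for e in range(num_epochs)]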
251
+
252
+ _SHARD_SHUFFLE_SIZE = 2000
253
+ _SHARD_SHUFFLE_INITIAL = 500
254
+ _SAMPLE_SHUFFLE_SIZE = 5000
255
+ _SAMPLE_SHUFFLE_INITIAL = 1000
256
+
257
+
258
+ class ResampledShards2(IterableDataset):
259
+ """An iterable dataset yielding a list of urls."""
260
+
261
+ def __init__(
262
+ self,
263
+ urls,
264
+ nshards=sys.maxsize,
265
+ worker_seed=None,
266
+ deterministic=False,
267
+ epoch=-1,
268
+ ):
269
+ """Sample shards from the shard list with replacement.
270
+ :param urls: a list of URLs as a Python list or brace notation string
271
+ """
272
+ super().__init__()
273
+ urls = wds.shardlists.expand_urls(urls)
274
+ self.urls = urls
275
+ assert isinstance(self.urls[0], str)
276
+ self.nshards = nshards
277
+ self.rng = random.Random()
278
+ self.worker_seed = worker_seed
279
+ self.deterministic = deterministic
280
+ self.epoch = epoch
281
+
282
+ def __iter__(self):
283
+ """Return an iterator over the shards."""
284
+ if isinstance(self.epoch, SharedEpoch):
285
+ epoch = self.epoch.get_value()
286
+ else:
287
+ # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
288
+ # situation as different workers may wrap at different times (or not at all).
289
+ self.epoch += 1
290
+ epoch = self.epoch
291
+
292
+ if self.deterministic:
293
+ # reset seed w/ epoch if deterministic
294
+ if self.worker_seed is None:
295
+ # pytorch worker seed should be deterministic since it is initialized from args.seed + rank + worker id
296
+ seed = pytorch_worker_seed(epoch)
297
+ else:
298
+ seed = self.worker_seed() + epoch
299
+ seed = seed + int(time.time())
300
+ self.rng.seed(seed)
301
+ # logging.info(f"epoch: {epoch} seed: {seed}")
302
+ self.rng.shuffle(self.urls)
303
+ # logging.info(f"{len(self.urls)} | {self.urls[:2]}")
304
+ for url in self.urls:
305
+ # logging.info(f"{seed}: {url}")
306
+ yield dict(url=url)
307
+
308
+
309
+ def preprocess_image(sample, image_processor):
310
+ image = image_processor(sample)
311
+ return image
312
+
313
+
314
+ def preprocess_text(sample, tokenizer, max_length, single=False):
315
+ if not single:
316
+ text = tokenizer(tokenizer.bos_token+sample.strip(), return_tensors="pt", max_length=max_length, truncation=True)
317
+ else:
318
+ text = tokenizer(tokenizer.bos_token+sample.strip(), return_tensors="pt", max_length=max_length, truncation=True, padding='max_length')
319
+ return text["input_ids"], text["attention_mask"]
320
+
321
+
322
+ def preprocess_encoded_text(sample, tokenizer, max_length):
323
+ sample = sample.decode("utf-8")
324
+ return preprocess_text(sample, tokenizer, max_length=max_length)
325
+
326
+
327
+ def _merge_bbox_previsual(added_bbox_list):
328
+ bbox_list = []
329
+ for bboxes in added_bbox_list:
330
+ x1 = bboxes[:, 0].min()
331
+ y1 = bboxes[:, 1].min()
332
+ x2 = bboxes[:, 2].max()
333
+ y2 = bboxes[:, 3].max()
334
+ bbox_list.append(torch.tensor([x1, y1, x2, y2], device=bboxes.device, dtype=bboxes.dtype).unsqueeze(0))
335
+ return bbox_list
336
+
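An illustrative call to the helper above (made-up normalized coordinates): each group of phrase boxes is collapsed into a single enclosing box, which is what later gets attached to one prebox token.

import torch
boxes = torch.tensor([[0.10, 0.20, 0.40, 0.50],
                      [0.30, 0.10, 0.60, 0.45]])
_merge_bbox_previsual([boxes])  # -> [tensor([[0.10, 0.10, 0.60, 0.50]])]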
337
+
338
+ def _find_idx(text, subtext):
339
+ loc = 0
340
+ locs = []
341
+ while text.find(subtext, loc) != -1:
342
+ loc = text.find(subtext, loc)
343
+ locs.append(loc)
344
+ loc += len(subtext)
345
+ return locs
346
+
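For reference, a quick example of the substring-index helper above:

_find_idx("a <|#visual#|> b <|#visual#|>", "<|#visual#|>")  # -> [2, 17]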
347
+ def preprocess_ground_caption(sample, image_processor, tokenizer, image_embedding_size, generator, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None, args=None):
348
+ assert max_length is not None
349
+ assert not single, "single is not supported for preprocess_ground_caption"
350
+ image, caption, logits_filt, boxes_filt, relation_data = sample
351
+ if len(logits_filt.shape) == 1 and logits_filt.shape[0] == 4 and len(boxes_filt.shape) == 1 and boxes_filt.shape[0] == 4:
352
+ raise NotImplementedError # relation data is missing for this sample; the visual-genome fallback below is unreachable
353
+ return preprocess_visual_genome(sample=sample, image_processor=image_processor, tokenizer=tokenizer, image_embedding_size=image_embedding_size, prob_ground=prob_ground, single=single, use_format_v2=use_format_v2, add_visual_token=add_visual_token, max_length=max_length)
354
+ image = preprocess_image(image, image_processor=image_processor)
355
+ added_bbox = []
356
+ if (prob_ground != 0 and random.random() <= prob_ground) or prob_ground == 1.0:
357
+ boxes_filt, pred_phrases = generator.postprocess(logits_filt, boxes_filt, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
358
+ caption, added_bbox = add_loc_to_text(
359
+ boxes_filt, pred_phrases, caption,
360
+ expand=args.expand, always_expand=args.longer_previsual,
361
+ )
362
+ visual_loc = []
363
+ obj_loc = []
364
+ endofobj_loc = []
365
+ visual_token = "<|#visual#|>"
366
+ previsual_token = "<|#previsual#|>"
367
+ box_token = "<|#box#|>"
368
+ prebox_token = "<|#prebox#|>"
369
+ end_token = "<|#endofobject#|>"
370
+ object_token = "<|#object#|>"
371
+ end_of_attr_token = "<|#endofattr#|>"
372
+ preend_of_attr_token = "<|#preendofattr#|>"
373
+ visual_loc = _find_idx(caption, visual_token)
374
+ try:
375
+ if len(visual_loc) != len(added_bbox):
376
+ logging.warning(f"visual_loc: {visual_loc}")
377
+ logging.warning(f"added_bbox: {added_bbox}")
378
+ except Exception: # logging the mismatch is best-effort; the assert below still enforces it
379
+ pass
380
+ assert len(visual_loc) == len(added_bbox)
381
+ delta = 0
382
+ for i, (loc, boxes) in enumerate(zip(visual_loc, added_bbox)):
383
+ loc += delta
384
+ boxes = nms_without_score(boxes)
385
+ added_bbox[i] = boxes
386
+ added_tokens = end_token + visual_token + box_token * len(boxes) + end_of_attr_token
387
+ caption = caption[:loc] + added_tokens + caption[len(visual_token) + loc:]
388
+ delta += len(added_tokens) - len(visual_token)
389
+
390
+ if use_format_v2:
391
+ merge_added_bbox = _merge_bbox_previsual(added_bbox)
392
+ # step 1: move <|#object#|> before the space char
393
+ while caption.find(f" {object_token}") != -1:
394
+ caption = caption.replace(f" {object_token}", f"{object_token} ")
395
+ # step 2: add <|#previsual#|> after <|#object#|> for 75% except the first object
396
+ i = 0
397
+ II = -1
398
+ if args.no_visual:
399
+ flag = False
400
+ delete_visual_prob = 10.0
401
+ else:
402
+ flag = True
403
+ delete_visual_prob = 0.75
404
+ while i < len(caption):
405
+ if caption[i: i + len(object_token)] == object_token:
406
+ II += 1
407
+ if (not args.longer_previsual and not flag and random.random() < delete_visual_prob) or (args.longer_previsual and (flag or random.random() < delete_visual_prob)):
408
+ # delete visual and add previsual
409
+ visual_start_idx = caption.find(end_token, i+1) + len(end_token)
410
+ visual_end_idx = caption.find(end_of_attr_token, visual_start_idx+1) + len(end_of_attr_token)
411
+ caption = caption[:visual_start_idx] + caption[visual_end_idx:]
412
+ caption = caption[:i + len(object_token)] + previsual_token + prebox_token + preend_of_attr_token + caption[i + len(object_token):]
413
+ added_bbox[II] = merge_added_bbox[II]
414
+ i += 1
415
+ flag = False
416
+ if args.no_previsual and args.no_visual:
417
+ caption = caption.replace(previsual_token, "").replace(prebox_token, "").replace(preend_of_attr_token, "")
418
+ added_bbox = []
419
+ caption = caption.replace(preend_of_attr_token, object_token).replace(end_of_attr_token, end_token)
420
+
421
+
422
+ if args.roi_align:
423
+ i = 0
424
+ pad_num = args.roi_output_size ** 2 - 1
425
+ while i < len(caption):
426
+ if caption[i: i + len(prebox_token)] == prebox_token:
427
+ caption = caption[:i] + tokenizer.pad_token * pad_num + caption[i:]
428
+ i += len(tokenizer.pad_token) * pad_num + len(prebox_token)
429
+ elif caption[i: i + len(box_token)] == box_token:
430
+ caption = caption[:i] + tokenizer.pad_token * pad_num + caption[i:]
431
+ i += len(tokenizer.pad_token) * pad_num + len(box_token)
432
+ i += 1
433
+
434
+ caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
435
+ input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
436
+ relations = []
437
+ if args.only_grounded_sample and "<|#visual#|>" not in caption:
438
+ raise ValueError
439
+ return image, input_ids, attention_mask, added_bbox, relations
440
+
441
+
442
+ def preprocess_visual_genome(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
443
+ assert max_length is not None
444
+ assert not single, "single is not supported for preprocess_visual_genome"
445
+ image, caption, xyxy, _ = sample
446
+ image = preprocess_image(image, image_processor=image_processor)
447
+ caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|><|#object#|>" + caption.strip() + "<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
448
+ input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
449
+ added_bbox = [torch.tensor(np.expand_dims(xyxy, 0).astype(np.float32) / 224)]
450
+ return image, input_ids, attention_mask, added_bbox
451
+
452
+ special_predicate = [
453
+ "and",
454
+ "has",
455
+ "says",
456
+ "wears",
457
+ ]
458
+
459
+ original_predicate = {
460
+ "and": "and",
461
+ "has": "have",
462
+ "says": "say",
463
+ "wears": "wear",
464
+ }
465
+
466
+
467
+ def generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation):
468
+ if relation in ["and", "of"]:
469
+ id = 0
470
+ else:
471
+ id = random.choice(range(len(VG_RELATION_TEMPLATES)))
472
+ text = VG_RELATION_TEMPLATES[id].format(nameA=nameA, nameB=nameB, relation=relation, use_is="is" if relation not in special_predicate else "", is_or_does="is" if relation not in special_predicate else "does", relation_do=relation if relation not in special_predicate else original_predicate[relation])
473
+ if id in [0]:
474
+ added_bbox = [
475
+ torch.tensor([boxA]),
476
+ torch.tensor([boxB]),
477
+ ]
478
+ elif id in [1]:
479
+ added_bbox = [
480
+ torch.tensor([boxA]),
481
+ torch.tensor([boxB]),
482
+ torch.tensor([boxA]),
483
+ torch.tensor([boxB]),
484
+ ]
485
+ elif id in [2]:
486
+ added_bbox = [
487
+ torch.tensor([boxA]),
488
+ torch.tensor([boxA]),
489
+ torch.tensor([boxB]),
490
+ ]
491
+ elif id in [3]:
492
+ added_bbox = [
493
+ torch.tensor([boxB]),
494
+ torch.tensor([boxA]),
495
+ torch.tensor([boxB]),
496
+ ]
497
+ elif id in [4]:
498
+ added_bbox = [
499
+ torch.tensor([boxA]),
500
+ torch.tensor([boxB]),
501
+ ]
502
+ elif id in [5]:
503
+ added_bbox = [
504
+ torch.tensor([boxB]),
505
+ torch.tensor([boxA]),
506
+ ]
507
+ else:
508
+ raise NotImplementedError
509
+ return text, added_bbox
510
+
511
+ def generate_pisc_sample(boxA, boxB, relation):
512
+ id = random.choice(range(len(PISC_TEMPLATES)))
513
+ text = PISC_TEMPLATES[id].format(relation=relation)
514
+ if id in [0]:
515
+ if random.random() < 0.5:
516
+ added_bbox = [
517
+ torch.tensor([boxA]),
518
+ torch.tensor([boxB]),
519
+ ]
520
+ else:
521
+ added_bbox = [
522
+ torch.tensor([boxB]),
523
+ torch.tensor([boxA]),
524
+ ]
525
+ elif id in [1]:
526
+ if random.random() < 0.5:
527
+ added_bbox = [torch.tensor([boxA, boxB])]
528
+ else:
529
+ added_bbox = [torch.tensor([boxB, boxA])]
530
+ return text, added_bbox
531
+
532
+
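A sketch of how the PISC sampler above might be called (the boxes are made-up normalized xyxy values, not from this commit); template 0 keeps one box per box token, while template 1 packs both people into a single two-row tensor.

text, added_bbox = generate_pisc_sample(
    boxA=[0.10, 0.15, 0.45, 0.95],
    boxB=[0.50, 0.20, 0.85, 0.95],
    relation="friends",
)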
533
+ def preprocess_instruct(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
534
+ image_path, dataset, data = sample
535
+ image = Image.open(image_path)
536
+ size = image_processor.transforms[0].size
537
+ image = image.resize((size, size))
538
+ if dataset == "pisc_relation_split":
539
+ boxA = data[0]
540
+ boxB = data[1]
541
+ relation = data[2]
542
+ text, added_bbox = generate_pisc_sample(boxA, boxB, relation)
543
+ # import cv2
544
+ # boxA *= size
545
+ # boxB *= size
546
+ # open_cv_image = np.array(image)
547
+ # open_cv_image = open_cv_image[:, :, ::-1].copy()
548
+ # open_cv_image = cv2.rectangle(open_cv_image, boxA[:2].astype(int), boxA[2:].astype(int), (255, 0, 0), 2)
549
+ # open_cv_image = cv2.rectangle(open_cv_image, boxB[:2].astype(int), boxB[2:].astype(int), (0, 255, 0), 2)
550
+ # cv2.imwrite("output.jpg", open_cv_image)
551
+ # import pdb; pdb.set_trace()
552
+ elif dataset == "vg_relation":
553
+ boxA = data[0][0]
554
+ nameA = data[0][1]
555
+ boxB = data[1][0]
556
+ nameB = data[1][1]
557
+ relation = data[2]
558
+ text, added_bbox = generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation)
559
+ image = preprocess_image(image, image_processor=image_processor)
560
+ caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + text + tokenizer.eos_token
561
+ input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=True)
562
+ # return image, input_ids, attention_mask, added_bbox
563
+ images = image.unsqueeze(0)
564
+ image_start_index_list = [2]
565
+ return images, len(images), image_start_index_list, input_ids, attention_mask, added_bbox
566
+
567
+
568
+ def preprocess_caption(sample, image_processor, tokenizer, image_embedding_size, max_length, single=False):
569
+ image, caption = sample
570
+ caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
571
+ image = preprocess_image(image, image_processor=image_processor)
572
+ input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=single)
573
+ return image, input_ids, attention_mask
574
+
575
+
576
+ def get_pile_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
577
+ input_shards = args.pile_shards
578
+ assert input_shards is not None
579
+ resampled = getattr(args, "dataset_resampled", False)
580
+ assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
581
+
582
+ # create a shared epoch store to sync epoch to dataloader worker proc
583
+ shared_epoch = SharedEpoch(epoch=epoch)
584
+ preprocess_text_fn = functools.partial(preprocess_encoded_text, tokenizer=tokenizer, max_length=args.max_length)
585
+ pipeline = [
586
+ ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
587
+ tarfile_to_samples_nothrow,
588
+ wds.shuffle(
589
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
590
+ initial=_SAMPLE_SHUFFLE_INITIAL,
591
+ ),
592
+ wds.to_tuple("txt", handler=log_and_continue),
593
+ wds.map_tuple(
594
+ preprocess_text_fn, handler=log_and_continue
595
+ ),
596
+ ]
597
+ # with_epoch(sys.maxsize) will give us an infinite sample stream
598
+ dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
599
+ delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
600
+ dataset = ConcatDataset(iter(dataset), max_length=args.max_length, delimiter_id=delimiter_id)
601
+
602
+
603
+ def text_collate_fn(items):
604
+ try:
605
+ input_ids = torch.cat([x[0].unsqueeze(0) for x in items], dim=0)
606
+ attention_mask = torch.cat([x[1].unsqueeze(0) for x in items], dim=0)
607
+ return input_ids, attention_mask
608
+ except Exception: # drop malformed text batches
609
+ return None, None
610
+
611
+ dataloader = wds.WebLoader(
612
+ dataset,
613
+ batch_size=args.batch_size_pile,
614
+ shuffle=False,
615
+ num_workers=args.workers,
616
+ persistent_workers=False,
617
+ collate_fn=text_collate_fn,
618
+ )
619
+ return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
620
+
621
+
622
+ # FIXME:
623
+ # modify /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/webdataset/filters.py, line 433
624
+ # combine_tensors=True to combine_tensors=False
625
+ def get_ground_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
626
+ input_shards = args.laion_shards
627
+ assert input_shards is not None
628
+ resampled = getattr(args, "dataset_resampled", False)
629
+ assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
630
+ # create a shared epoch store to sync epoch to dataloader worker proc
631
+ shared_epoch = SharedEpoch(epoch=epoch)
632
+ generator = caption_grounder(
633
+ config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
634
+ checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
635
+ cpu_only=True,
636
+ # box_threshold=0.5, text_threshold=0.3,
637
+ )
638
+ preprocess_ground_caption_fn = functools.partial(
639
+ preprocess_ground_caption, image_processor=image_processor, tokenizer=tokenizer,
640
+ image_embedding_size=args.vis_embed_size, single=args.single, generator=generator,
641
+ prob_ground=args.prob_ground, use_format_v2=args.use_format_v2,
642
+ add_visual_token=args.add_visual_token, max_length=args.max_length,
643
+ args=args,
644
+ )
645
+ pipeline = [
646
+ ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
647
+ tarfile_to_samples_nothrow,
648
+ wds.shuffle(
649
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
650
+ initial=_SAMPLE_SHUFFLE_INITIAL,
651
+ ),
652
+ wds.select(filter_no_caption_or_no_image),
653
+ wds.decode("pilrgb", partial=True, handler=log_and_continue),
654
+ wds.to_tuple("jpg;png;jpeg", "txt", "logits.pyd", "boxes.pyd", "relation_data", handler=log_and_continue),
655
+ wds.map(
656
+ preprocess_ground_caption_fn, handler=log_and_continue
657
+ ),
658
+ ]
659
+
660
+ dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
661
+ # for sample in dataset:
662
+ # print(tokenizer.decode(sample[1][0]).replace("<PAD>", ""))
663
+ # DEBUG
664
+ # dataset = wds.DataPipeline(*pipeline)
665
+ # from tqdm import tqdm
666
+ # for sample in tqdm(dataset):
667
+ # nn = 0
668
+ # for x in sample[1][0]:
669
+ # if x == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]:
670
+ # nn += 1
671
+ # if x == tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]:
672
+ # nn -= 1
673
+ # if nn not in [0, 1]:
674
+ # print(tokenizer.decode(sample[1][0]).replace("<PAD>", ""))
675
+ # import pdb; pdb.set_trace()
676
+ # if nn != 0:
677
+ # print(tokenizer.decode(sample[1][0]).replace("<PAD>", ""))
678
+ # import pdb; pdb.set_trace()
679
+ # from groundingdino.demo.inference_on_laion import OBJ_LENGTHS
680
+ # # import pdb; pdb.set_trace()
681
+ # print(sum(OBJ_LENGTHS) / len(OBJ_LENGTHS))
682
+ # exit()
683
+ # DEBUG
684
+
685
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
686
+ delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
687
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
688
+ box_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
689
+ visual_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
690
+ dataset = ConcatDataset(
691
+ iter(dataset), max_length=args.max_length,
692
+ delimiter_id=delimiter_id,
693
+ pad_id=tokenizer.pad_token_id,
694
+ media_id=media_token_id,
695
+ endofmedia_id=endofmedia_token_id,
696
+ box_id=box_id,
697
+ visual_id=visual_id,
698
+ image_embedding_size=args.vis_embed_size,
699
+ single=args.single,
700
+ )
701
+
702
+ def image_collate_fn(items):
703
+ images = torch.cat([x[0] for x in items], dim=0)
704
+ image_nums = [x[1] for x in items]
705
+ image_start_index_list = [x[2] for x in items]
706
+ input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
707
+ attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
708
+ added_bbox_list = [x[5] for x in items]
709
+ expand_list = added_bbox_list[0]
710
+ for x in added_bbox_list[1:]:
711
+ expand_list.extend(x)
712
+ relations_list = [x[6] for x in items]
713
+ return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list, relations_list
714
+
715
+ dataloader = wds.WebLoader(
716
+ dataset,
717
+ batch_size=args.batch_size_laion,
718
+ shuffle=False,
719
+ num_workers=args.workers,
720
+ persistent_workers=False,
721
+ collate_fn=image_collate_fn,
722
+ )
723
+ round_fn = math.floor if floor else math.ceil
724
+ global_batch_size = args.batch_size_laion * args.world_size
725
+ num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
726
+ dataloader.num_batches = num_batches
727
+ return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
728
+
729
+
730
+ def get_image_text_pair_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
731
+ input_shards = args.laion_shards
732
+ assert input_shards is not None
733
+ resampled = getattr(args, "dataset_resampled", False)
734
+ assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
735
+ # create a shared epoch store to sync epoch to dataloader worker proc
736
+ shared_epoch = SharedEpoch(epoch=epoch)
737
+ preprocess_caption_fn = functools.partial(
738
+ preprocess_caption, image_processor=image_processor, tokenizer=tokenizer,
739
+ image_embedding_size=args.vis_embed_size, single=args.single,
740
+ max_length=args.max_length,
741
+ )
742
+ pipeline = [
743
+ ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
744
+ tarfile_to_samples_nothrow,
745
+ wds.shuffle(
746
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
747
+ initial=_SAMPLE_SHUFFLE_INITIAL,
748
+ ),
749
+ wds.select(filter_no_caption_or_no_image),
750
+ wds.decode("pilrgb", handler=log_and_continue),
751
+ wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
752
+ wds.map(
753
+ preprocess_caption_fn, handler=log_and_continue
754
+ ),
755
+ ]
756
+
757
+ dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
758
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
759
+ delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
760
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
761
+ dataset = ConcatDataset(
762
+ iter(dataset), max_length=args.max_length,
763
+ delimiter_id=delimiter_id,
764
+ pad_id=tokenizer.pad_token_id,
765
+ media_id=media_token_id,
766
+ endofmedia_id=endofmedia_token_id,
767
+ image_embedding_size=args.vis_embed_size,
768
+ single=args.single,
769
+ )
770
+
771
+ def image_collate_fn(items):
772
+ images = torch.cat([x[0] for x in items], dim=0)
773
+ image_nums = [x[1] for x in items]
774
+ image_start_index_list = [x[2] for x in items]
775
+ input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
776
+ attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
777
+ return images, image_nums, image_start_index_list, input_ids, attention_mask
778
+
779
+ dataloader = wds.WebLoader(
780
+ dataset,
781
+ batch_size=args.batch_size_laion,
782
+ shuffle=False,
783
+ num_workers=args.workers,
784
+ persistent_workers=False,
785
+ collate_fn=image_collate_fn,
786
+ )
787
+ round_fn = math.floor if floor else math.ceil
788
+ global_batch_size = args.batch_size_laion * args.world_size
789
+ num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
790
+ dataloader.num_batches = num_batches
791
+ return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
792
+
793
+
794
+ def get_instruct_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
795
+ input_shards = args.laion_shards
796
+ assert input_shards is not None
797
+ resampled = getattr(args, "dataset_resampled", False)
798
+ assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
799
+ # create a shared epoch store to sync epoch to dataloader worker proc
800
+ shared_epoch = SharedEpoch(epoch=epoch)
801
+ preprocess_instruct_fn = functools.partial(
802
+ preprocess_instruct, image_processor=image_processor, tokenizer=tokenizer,
803
+ image_embedding_size=args.vis_embed_size,
804
+ max_length=args.max_length,
805
+ )
806
+ pipeline = [
807
+ ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
808
+ tarfile_to_samples_nothrow,
809
+ wds.shuffle(
810
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
811
+ initial=_SAMPLE_SHUFFLE_INITIAL,
812
+ ),
813
+ wds.decode(partial=True),
814
+ wds.to_tuple("image_path.txt", "dataset.txt", "data.pyd", handler=log_and_continue),
815
+ wds.map(
816
+ preprocess_instruct_fn, handler=log_and_continue
817
+ ),
818
+ ]
819
+ dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
820
+
821
+ def image_collate_fn(items):
822
+ images = torch.cat([x[0] for x in items], dim=0)
823
+ image_nums = [x[1] for x in items]
824
+ image_start_index_list = [x[2] for x in items]
825
+ input_ids = torch.cat([x[3] for x in items], dim=0)
826
+ attention_mask = torch.cat([x[4] for x in items], dim=0)
827
+ added_bbox_list = [x[5] for x in items]
828
+ expand_list = added_bbox_list[0]
829
+ for x in added_bbox_list[1:]:
830
+ expand_list.extend(x)
831
+ return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list
832
+
833
+ dataloader = wds.WebLoader(
834
+ dataset,
835
+ batch_size=args.batch_size_laion,
836
+ shuffle=False,
837
+ num_workers=args.workers,
838
+ persistent_workers=False,
839
+ collate_fn=image_collate_fn,
840
+ )
841
+ round_fn = math.floor if floor else math.ceil
842
+ global_batch_size = args.batch_size_laion * args.world_size
843
+ num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
844
+ dataloader.num_batches = num_batches
845
+ return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
846
+
847
+
848
+ def get_dataset_fn(dataset_type):
849
+ if dataset_type == "mmc4":
850
+ raise NotImplementedError
851
+ elif dataset_type == "pile":
852
+ return get_pile_dataset
853
+ elif dataset_type == "ground_image_text":
854
+ return get_ground_laion_dataset
855
+ elif dataset_type == "image_text":
856
+ return get_image_text_pair_dataset
857
+ elif dataset_type == "vqav2":
858
+ raise NotImplementedError
859
+ elif dataset_type == "instruct":
860
+ return get_instruct_dataset
861
+ else:
862
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
863
+
864
+
865
+ def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
866
+ return get_dataset_fn(dataset_type)(
867
+ args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
868
+ )
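For context, the training script is expected to reach the dispatcher above roughly like this (a sketch; the real args namespace is built by argparse in train.py):

laion_data = get_data(args, image_processor, tokenizer, "ground_image_text")
pile_data = get_data(args, image_processor, tokenizer, "pile")
laion_data.set_epoch(0)  # propagate the epoch into worker processes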
multimodal/build/lib/open_flamingo/train/distributed.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+
3
+ import torch
4
+
5
+ try:
6
+ import horovod.torch as hvd
7
+ except ImportError:
8
+ hvd = None
9
+
10
+
11
+ def is_global_master(args):
12
+ return args.rank == 0
13
+
14
+
15
+ def is_local_master(args):
16
+ return args.local_rank == 0
17
+
18
+
19
+ def is_master(args, local=False):
20
+ return is_local_master(args) if local else is_global_master(args)
21
+
22
+
23
+ def is_using_horovod():
24
+ # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
25
+ # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
26
+ ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
27
+ pmi_vars = ["PMI_RANK", "PMI_SIZE"]
28
+ if all([var in os.environ for var in ompi_vars]) or all(
29
+ [var in os.environ for var in pmi_vars]
30
+ ):
31
+ return True
32
+ else:
33
+ return False
34
+
35
+
36
+ def is_using_distributed():
37
+ if "WORLD_SIZE" in os.environ:
38
+ return int(os.environ["WORLD_SIZE"]) > 1
39
+ if "SLURM_NTASKS" in os.environ:
40
+ return int(os.environ["SLURM_NTASKS"]) > 1
41
+ return False
42
+
43
+
44
+ def world_info_from_env():
45
+ local_rank = 0
46
+ for v in (
47
+ "LOCAL_RANK",
48
+ "MPI_LOCALRANKID",
49
+ "SLURM_LOCALID",
50
+ "OMPI_COMM_WORLD_LOCAL_RANK",
51
+ ):
52
+ if v in os.environ:
53
+ local_rank = int(os.environ[v])
54
+ break
55
+ global_rank = 0
56
+ for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
57
+ if v in os.environ:
58
+ global_rank = int(os.environ[v])
59
+ break
60
+ world_size = 1
61
+ for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
62
+ if v in os.environ:
63
+ world_size = int(os.environ[v])
64
+ break
65
+
66
+ return local_rank, global_rank, world_size
67
+
68
+
69
+ def init_distributed_device(args):
70
+ # Distributed training = training on more than one GPU.
71
+ # Works in both single and multi-node scenarios.
72
+ args.distributed = False
73
+ args.world_size = 1
74
+ args.rank = 0 # global rank
75
+ args.local_rank = 0
76
+ if args.horovod:
77
+ assert hvd is not None, "Horovod is not installed"
78
+ hvd.init()
79
+ args.local_rank = int(hvd.local_rank())
80
+ args.rank = hvd.rank()
81
+ args.world_size = hvd.size()
82
+ args.distributed = True
83
+ os.environ["LOCAL_RANK"] = str(args.local_rank)
84
+ os.environ["RANK"] = str(args.rank)
85
+ os.environ["WORLD_SIZE"] = str(args.world_size)
86
+ elif is_using_distributed():
87
+ if "SLURM_PROCID" in os.environ:
88
+ # DDP via SLURM
89
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
90
+ # SLURM var -> torch.distributed vars in case needed
91
+ os.environ["LOCAL_RANK"] = str(args.local_rank)
92
+ os.environ["RANK"] = str(args.rank)
93
+ os.environ["WORLD_SIZE"] = str(args.world_size)
94
+ torch.distributed.init_process_group(
95
+ backend=args.dist_backend,
96
+ init_method=args.dist_url,
97
+ world_size=args.world_size,
98
+ rank=args.rank,
99
+ )
100
+ else:
101
+ # DDP via torchrun, torch.distributed.launch
102
+ args.local_rank, _, _ = world_info_from_env()
103
+ torch.distributed.init_process_group(
104
+ backend=args.dist_backend, init_method=args.dist_url
105
+ )
106
+ args.world_size = torch.distributed.get_world_size()
107
+ args.rank = torch.distributed.get_rank()
108
+ args.distributed = True
109
+ else:
110
+ # needed to run on single gpu
111
+ torch.distributed.init_process_group(
112
+ backend=args.dist_backend,
113
+ init_method=args.dist_url,
114
+ world_size=1,
115
+ rank=0,
116
+ )
117
+
118
+ if torch.cuda.is_available():
119
+ if args.distributed and not args.no_set_device_rank:
120
+ device = "cuda:%d" % args.local_rank
121
+ else:
122
+ device = "cuda:0"
123
+ torch.cuda.set_device(device)
124
+ else:
125
+ device = "cpu"
126
+ args.device = device
127
+ device = torch.device(device)
128
+ return device
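As a reference point, the environment that world_info_from_env reads typically looks like one of the following (values are illustrative):

# torchrun / torch.distributed.launch:  LOCAL_RANK=0  RANK=3  WORLD_SIZE=8
# SLURM:                                SLURM_LOCALID=0  SLURM_PROCID=3  SLURM_NTASKS=8
local_rank, global_rank, world_size = world_info_from_env()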
multimodal/build/lib/open_flamingo/train/instruction_template.py ADDED
@@ -0,0 +1,13 @@
1
+ VG_RELATION_TEMPLATES = [
2
+ "Question: What is the relationship between<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> and<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer: {relation}.",
3
+ "Question: What is the relationship between<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> and<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
4
+ "Question: What {is_or_does}<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {relation_do}? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {use_is} {relation}<|#object#|>{nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
5
+ "Question: What {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
6
+ "Question: What {is_or_does}<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {relation_do}? Answer:<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
7
+ "Question: What {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
8
+ ]
9
+
10
+ PISC_TEMPLATES = [
11
+ "Question: What is the social relationship between this<|#object#|> person<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> and that<|#object#|> person<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer: {relation}.",
12
+ "Question: What is the social relationship between these<|#object#|> people<|#endofobject#|><|#visual#|><|#box#|><|#box#|><|#endofobject#|>? Answer: {relation}.",
13
+ ]
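An illustrative instantiation of template 1 above, using the same keyword arguments that generate_vg_relation_sample in data2.py supplies (the names and relation are made up):

VG_RELATION_TEMPLATES[1].format(
    nameA="man", nameB="horse", relation="riding",
    use_is="is", is_or_does="is", relation_do="riding",
)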
multimodal/build/lib/open_flamingo/train/train.py ADDED
@@ -0,0 +1,709 @@
1
+ """ Main training script """
2
+
3
+ import argparse
4
+ import copy
5
+ import glob
6
+ import os
7
+ import random
8
+ import functools
9
+
10
+ import numpy as np
11
+ import torch
12
+ # torch.multiprocessing.set_sharing_strategy('file_system')
13
+ import wandb
14
+ from data2 import get_data
15
+ from distributed import init_distributed_device, world_info_from_env
16
+ from torch.distributed.fsdp import (
17
+ FullyShardedDataParallel as FSDP,
18
+ MixedPrecision,
19
+ BackwardPrefetch,
20
+ ShardingStrategy,
21
+ FullStateDictConfig,
22
+ CPUOffload,
23
+ StateDictType,
24
+ )
25
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
26
+ from torch.distributed.fsdp.wrap import (
27
+ transformer_auto_wrap_policy,
28
+ enable_wrap,
29
+ wrap,
30
+ )
31
+
32
+ from train_utils import train_one_epoch
33
+ from transformers import (
34
+ get_constant_schedule_with_warmup,
35
+ get_cosine_schedule_with_warmup,
36
+ get_linear_schedule_with_warmup,
37
+ )
38
+
39
+ from open_flamingo import create_model_and_transforms
40
+ from torch.utils.tensorboard import SummaryWriter
41
+ from torch.nn.parallel import DistributedDataParallel as DDP
42
+ from torch.cuda.amp import GradScaler
43
+ from torch.distributed.optim import ZeroRedundancyOptimizer
44
+ import warnings
45
+ warnings.filterwarnings("ignore")
46
+ import logging
47
+ logging.basicConfig(
48
+ level=logging.INFO,
49
+ format='%(asctime)s %(message)s',
50
+ datefmt='%m/%d %I:%M:%S',
51
+ )
52
+
53
+ class FakeDataloader:
54
+ def __iter__(self):
55
+ return self
56
+
57
+ def __next__(self):
58
+ return None
59
+
60
+ def random_seed(seed=42, rank=0):
61
+ torch.manual_seed(seed + rank)
62
+ np.random.seed(seed + rank)
63
+ random.seed(seed + rank)
64
+
65
+
66
+ def get_grouped_params(model, args):
67
+ params_with_wd, params_without_wd = [], []
68
+
69
+ def apply_decay(x):
70
+ x = x.lower()
71
+ return "norm" not in x and "bn" not in x and "bias" not in x and "embed" not in x and "wte" not in x and "flat_param" not in x
72
+
73
+ for n, p in model.named_parameters():
74
+ # if p.requires_grad:
75
+ if apply_decay(n):
76
+ if torch.distributed.get_rank() == 0:
77
+ logging.info(f"with wd: {n}")
78
+ params_with_wd.append(p)
79
+ else:
80
+ if torch.distributed.get_rank() == 0:
81
+ logging.info(f"without wd: {n}")
82
+ params_without_wd.append(p)
83
+ return [
84
+ {"params": params_with_wd, "weight_decay": args.weight_decay},
85
+ {"params": params_without_wd, "weight_decay": 0.0},
86
+ ]
87
+
88
+
89
+ def lambda_policy_fn(module):
90
+ if (
91
+ len(list(module.named_children())) == 0
92
+ and getattr(module, "weight", None) is not None
93
+ and module.weight.requires_grad
94
+ ):
95
+ return True
96
+ return False
97
+
98
+
99
+ def lambda_auto_wrap_policy(
100
+ module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn,
101
+ ) -> bool:
102
+ """
103
+ A convenient auto wrap policy to wrap submodules based on an arbitrary user
104
+ function. If ``lambda_fn(submodule) == True``, the submodule will be wrapped as
105
+ a `wrapper_cls` unit.
106
+
107
+ Return if a module should be wrapped during auto wrapping.
108
+
109
+ The first three parameters are required by :func:`_recursive_wrap`.
110
+
111
+ Args:
112
+ module (nn.Module): Current module being considered.
113
+ recurse (bool): If ``False``, then this function must decide whether
114
+ ``module`` should be wrapped as an FSDP instance or not. If
115
+ ``True``, then the function is still recursing down the module
116
+ tree as a part of the DFS.
117
+ nonwrapped_numel (int): Parameter numel not yet wrapped.
118
+
119
+ lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
120
+ this module will be wrapped.
121
+ """
122
+ if recurse:
123
+ return True # always recurse
124
+ return lambda_fn(module)
125
+
126
+
127
+ def main():
128
+ parser = argparse.ArgumentParser()
129
+ parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str)
130
+ parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str)
131
+ parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
132
+ parser.add_argument(
133
+ "--tokenizer_path",
134
+ default="facebook/opt-1.3b",
135
+ type=str,
136
+ help="path to tokenizer",
137
+ )
138
+ parser.add_argument(
139
+ "--run_name",
140
+ type=str,
141
+ default="openflamingo3B",
142
+ help="used to name saving directory and wandb run",
143
+ )
144
+ parser.add_argument("--use_media_placement_augmentation", action="store_true")
145
+ parser.add_argument("--offline", action="store_true")
146
+ parser.add_argument("--num_steps", type=int, default=300000)
147
+ parser.add_argument(
148
+ "--logging_steps", type=int, default=10, help="log loss every n steps"
149
+ )
150
+ # Sum of gradient optimization batch size
151
+ parser.add_argument("--batch_size_mmc4", type=int, default=128)
152
+ parser.add_argument("--batch_size_laion", type=int, default=128)
153
+ parser.add_argument("--batch_size_pile", type=int, default=128)
154
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
155
+ parser.add_argument(
156
+ "--resume_from_checkpoint",
157
+ type=str,
158
+ help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
159
+ default=None,
160
+ )
161
+ parser.add_argument(
162
+ "--delete_previous_checkpoint",
163
+ action="store_true",
164
+ help="delete previous checkpoint when saving new checkpoint",
165
+ )
166
+ parser.add_argument(
167
+ "--laion_shards",
168
+ type=str,
169
+ help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
170
+ )
171
+ parser.add_argument(
172
+ "--mmc4_shards",
173
+ type=str,
174
+ help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
175
+ )
176
+ parser.add_argument(
177
+ "--pile_shards",
178
+ type=str,
179
+ default=None,
180
+ help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
181
+ )
182
+ parser.add_argument("--seed", type=int, default=42)
183
+ parser.add_argument("--learning_rate", default=1e-4, type=float)
184
+ parser.add_argument(
185
+ "--lr_scheduler",
186
+ default="constant",
187
+ type=str,
188
+ help="constant, linear, or cosine",
189
+ )
190
+ parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
191
+ parser.add_argument("--loss_multiplier_laion", type=float, default=1.0)
192
+ parser.add_argument("--loss_multiplier_pile", type=float, default=1.0)
193
+ parser.add_argument("--loss_multiplier_det", type=float, default=1.0)
194
+ parser.add_argument("--loss_multiplier_rel", type=float, default=1.0)
195
+ parser.add_argument("--loss_multiplier_attn", type=float, default=1.0)
196
+ parser.add_argument("--warmup_steps", default=5000, type=int)
197
+ # weight decay is only apply to YOLOX head if using FSDP
198
+ # https://medium.com/@huanghaian123/optimize-and-accelerate-yolox-with-rtmdet-hyps-in-mmyolo-80fc06d61159
199
+ parser.add_argument("--weight_decay", default=0.05, type=float)
200
+ parser.add_argument(
201
+ "--precision",
202
+ choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
203
+ default="fp32",
204
+ help="Floating point precision.",
205
+ )
206
+ # data args
207
+ parser.add_argument("--workers", type=int, default=1)
208
+ parser.add_argument("--dataset_resampled", action="store_true")
209
+ # distributed training args
210
+ parser.add_argument(
211
+ "--dist-url",
212
+ default="env://",
213
+ type=str,
214
+ help="url used to set up distributed training",
215
+ )
216
+ parser.add_argument(
217
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
218
+ )
219
+ parser.add_argument(
220
+ "--horovod",
221
+ default=False,
222
+ action="store_true",
223
+ help="Use horovod for distributed training.",
224
+ )
225
+ parser.add_argument(
226
+ "--no-set-device-rank",
227
+ default=False,
228
+ action="store_true",
229
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
230
+ )
231
+ # wandb args
232
+ parser.add_argument("--report_to_wandb", default=False, action="store_true")
233
+ parser.add_argument(
234
+ "--wandb_project",
235
+ type=str,
236
+ )
237
+ parser.add_argument(
238
+ "--wandb_entity",
239
+ type=str,
240
+ )
241
+ parser.add_argument(
242
+ "--save_checkpoints_to_wandb",
243
+ default=False,
244
+ action="store_true",
245
+ help="save checkpoints to wandb",
246
+ )
247
+ parser.add_argument(
248
+ "--checkpoint_activations",
249
+ default=False,
250
+ action="store_true",
251
+ )
252
+ parser.add_argument(
253
+ "--freeze_vision_encoder",
254
+ default=False,
255
+ action="store_true",
256
+ )
257
+ parser.add_argument(
258
+ "--mmc4_textsim_threshold",
259
+ default=30,
260
+ type=float,
261
+ help="threshold for filtering images in mmc4 based on image-text similarity",
262
+ )
263
+ parser.add_argument(
264
+ "--location_token_num",
265
+ default=1000,
266
+ type=int,
267
+ )
268
+ parser.add_argument(
269
+ "--vis_embed_size",
270
+ type=int,
271
+ required=False,
272
+ )
273
+ parser.add_argument(
274
+ "--save_interval",
275
+ default=1000,
276
+ type=int,
277
+ required=False,
278
+ )
279
+ parser.add_argument(
280
+ "--skip_delete_pattern",
281
+ default=1500,
282
+ type=int,
283
+ required=False,
284
+ )
285
+ parser.add_argument(
286
+ "--ddp",
287
+ default=False,
288
+ action="store_true",
289
+ )
290
+ parser.add_argument(
291
+ "--pile_freq",
292
+ default=1,
293
+ type=int,
294
+ required=False,
295
+ )
296
+ parser.add_argument(
297
+ "--restart",
298
+ default=False,
299
+ action="store_true",
300
+ )
301
+ parser.add_argument(
302
+ "--lora",
303
+ default=False,
304
+ action="store_true",
305
+ )
306
+ parser.add_argument(
307
+ "--lora_r",
308
+ default=16,
309
+ type=int,
310
+ required=False,
311
+ )
312
+ parser.add_argument(
313
+ "--single",
314
+ default=False,
315
+ action="store_true",
316
+ )
317
+
318
+ # Finetune
319
+ parser.add_argument(
320
+ "--instruct",
321
+ default=False,
322
+ action="store_true",
323
+ )
324
+ parser.add_argument(
325
+ "--fix-ffn",
326
+ default=False,
327
+ action="store_true",
328
+ )
329
+ parser.add_argument(
330
+ "--prob_ground",
331
+ default=1.0,
332
+ type=float,
333
+ required=False,
334
+ )
335
+ parser.add_argument(
336
+ "--optimizer",
337
+ default="adamw",
338
+ type=str,
339
+ required=False,
340
+ )
341
+ parser.add_argument(
342
+ "--add_visual_token",
343
+ default=False,
344
+ action="store_true",
345
+ )
346
+ parser.add_argument(
347
+ "--use_format_v2",
348
+ default=False,
349
+ action="store_true",
350
+ )
351
+ parser.add_argument(
352
+ "--use_sam",
353
+ default=None,
354
+ type=str,
355
+ required=False,
356
+ )
357
+ parser.add_argument(
358
+ "--max-length",
359
+ default=608,
360
+ type=int,
361
+ required=False,
362
+ )
363
+ parser.add_argument(
364
+ "--image-size",
365
+ default=256,
366
+ type=int,
367
+ required=False,
368
+ )
369
+ parser.add_argument(
370
+ "--reset_llm",
371
+ default=False,
372
+ action="store_true",
373
+ )
374
+ parser.add_argument(
375
+ "--add_box",
376
+ default=False,
377
+ action="store_true",
378
+ )
379
+ parser.add_argument(
380
+ "--add_pe",
381
+ default=False,
382
+ action="store_true",
383
+ )
384
+ parser.add_argument(
385
+ "--only_grounded_sample",
386
+ default=False,
387
+ action="store_true",
388
+ )
389
+ parser.add_argument(
390
+ "--expand",
391
+ default=False,
392
+ action="store_true",
393
+ )
394
+ parser.add_argument(
395
+ "--delete_contained",
396
+ default=False,
397
+ action="store_true",
398
+ )
399
+
400
+ parser.add_argument(
401
+ "--relation",
402
+ default=False,
403
+ action="store_true",
404
+ )
405
+ parser.add_argument(
406
+ "--attn_reg",
407
+ default="l1",
408
+ type=str,
409
+ required=False,
410
+ )
411
+ parser.add_argument(
412
+ "--enhance_data",
413
+ default=False,
414
+ action="store_true",
415
+ )
416
+ parser.add_argument(
417
+ "--no_visual",
418
+ default=False,
419
+ action="store_true",
420
+ )
421
+ parser.add_argument(
422
+ "--no_previsual",
423
+ default=False,
424
+ action="store_true",
425
+ )
426
+ parser.add_argument(
427
+ "--roi_align",
428
+ default=False,
429
+ action="store_true",
430
+ )
431
+ parser.add_argument(
432
+ "--roi_output_size",
433
+ default=4,
434
+ type=int,
435
+ required=False,
436
+ )
437
+ parser.add_argument(
438
+ "--apply_mask",
439
+ default=False,
440
+ action="store_true",
441
+ )
442
+ parser.add_argument(
443
+ "--longer_previsual",
444
+ default=False,
445
+ action="store_true",
446
+ )
447
+
448
+ args = parser.parse_args()
449
+ assert not args.use_media_placement_augmentation, "Do not enable use_media_placement_augmentation"
450
+ if args.no_previsual:
451
+ assert args.no_visual, "no_previsual MUST come with no_visual"
452
+ assert not args.enhance_data, "don't enable enhance_data"
453
+
454
+ if args.offline:
455
+ os.environ["WANDB_MODE"] = "offline"
456
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
457
+
458
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
459
+ print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
460
+ device_id = init_distributed_device(args)
461
+
462
+ random_seed(args.seed)
463
+ model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms(
464
+ args.vision_encoder_path,
465
+ args.vision_encoder_pretrained,
466
+ args.lm_path,
467
+ args.tokenizer_path if args.tokenizer_path else args.lm_path,
468
+ use_local_files=args.offline,
469
+ use_media_placement_augmentation=args.use_media_placement_augmentation,
470
+ checkpoint_activations=args.checkpoint_activations,
471
+ freeze_vision_encoder=args.freeze_vision_encoder,
472
+ location_token_num=args.location_token_num,
473
+ lora=args.lora,
474
+ lora_r=args.lora_r,
475
+ fix_ffn=args.fix_ffn,
476
+ add_visual_token=args.add_visual_token,
477
+ add_box=args.add_box,
478
+ add_pe=args.add_pe,
479
+ add_relation=args.relation,
480
+ use_format_v2=args.use_format_v2,
481
+ use_sam=args.use_sam,
482
+ enhance_data=args.enhance_data,
483
+ roi_align=args.roi_align,
484
+ roi_output_size=args.roi_output_size,
485
+ apply_mask=args.apply_mask,
486
+ )
487
+ if args.reset_llm:
488
+ llm_state_dict = model.lang_encoder.state_dict()
489
+ if args.rank == 0:
490
+ print(args)
491
+ print(image_processor)
492
+
493
+ random_seed(args.seed, args.rank)
494
+
495
+ if args.rank == 0 and args.report_to_wandb:
496
+ wandb.init(
497
+ project=args.wandb_project,
498
+ entity=args.wandb_entity,
499
+ name=args.run_name,
500
+ config=vars(args),
501
+ )
502
+
503
+ device_id = args.rank % torch.cuda.device_count()
504
+ if args.ddp:
505
+ print("use ddp mode")
506
+ model = model.to(device_id)
507
+ model = DDP(model)
508
+ else:
509
+ fpSixteen = MixedPrecision(
510
+ param_dtype=torch.float16,
511
+ # Gradient communication precision.
512
+ reduce_dtype=torch.float16,
513
+ # Buffer precision.
514
+ # buffer_dtype=torch.float16,
515
+ )
516
+ # from transformers.models.opt.modeling_opt import OPTDecoderLayer
517
+ from open_clip.transformer import ResidualAttentionBlock
518
+ from open_flamingo.src.flamingo_lm import FlamingoLayer
519
+ from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention
520
+ from segment_anything.modeling.image_encoder import Block
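+ # these classes become the FSDP auto-wrap units: Flamingo LM layers, CLIP residual attention blocks, and SAM image-encoder blocks (OPTAttention is added below when --fix_ffn is set)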
521
+ transformer_layer_cls=[
522
+ FlamingoLayer,
523
+ ResidualAttentionBlock,
524
+ Block,
525
+ ]
526
+ if args.fix_ffn:
527
+ transformer_layer_cls.append(OPTAttention)
528
+ auto_wrap_policy = functools.partial(
529
+ transformer_auto_wrap_policy,
530
+ transformer_layer_cls=transformer_layer_cls,
531
+ )
532
+ if args.lora:
533
+ from torch.distributed.fsdp.wrap import _or_policy
534
+ lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
535
+ auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
536
+ ignored_modules = [model.vision_encoder]
537
+ # ignored_modules = None
538
+ else:
539
+ ignored_modules = [model.detection_head]
540
+ # ignored_modules = None
541
+ if args.add_pe:
542
+ ignored_modules += [model.pos_enc]
543
+ # if args.use_format_v2:
544
+ # ignored_modules += [model.lang_encoder.visual_guided_lm_head]
545
+ model = FSDP(
546
+ model,
547
+ auto_wrap_policy=auto_wrap_policy,
548
+ mixed_precision=fpSixteen,
549
+ device_id=torch.cuda.current_device(),
550
+ ignored_modules=ignored_modules,
551
+ sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
552
+ )
553
+ model = model.to(device_id)
554
+
555
+
556
+ pile_dataset = None
557
+ if args.instruct:
558
+ laion_dataset = get_data(args, image_processor, tokenizer, "instruct")
559
+ else:
560
+ laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
561
+ if args.pile_shards is not None:
562
+ pile_dataset = get_data(args, image_processor, tokenizer, "pile")
563
+
564
+
565
+ optim_groups = get_grouped_params(model, args)
566
+ # optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
567
+ if args.ddp:
568
+ optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
569
+ # optimizer = ZeroRedundancyOptimizer(
570
+ # optim_groups,
571
+ # optimizer_class=torch.optim.AdamW,
572
+ # lr=args.learning_rate,
573
+ # parameters_as_bucket_view=True,
574
+ # )
575
+ else:
576
+ if args.optimizer == "adamw":
577
+ print("use adamw")
578
+ optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
579
+ elif args.optimizer == "sgd":
580
+ print("use sgd...")
581
+ optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
582
+ else:
583
+ raise NotImplementedError
584
+
585
+ total_training_steps = args.num_steps
586
+
587
+ if args.rank == 0:
588
+ logging.info(f"Total training steps: {total_training_steps}")
589
+
590
+ if args.lr_scheduler == "linear":
591
+ lr_scheduler = get_linear_schedule_with_warmup(
592
+ optimizer,
593
+ num_warmup_steps=args.warmup_steps,
594
+ num_training_steps=total_training_steps,
595
+ )
596
+ elif args.lr_scheduler == "cosine":
597
+ lr_scheduler = get_cosine_schedule_with_warmup(
598
+ optimizer,
599
+ num_warmup_steps=args.warmup_steps,
600
+ num_training_steps=total_training_steps,
601
+ )
602
+ else:
603
+ lr_scheduler = get_constant_schedule_with_warmup(
604
+ optimizer, num_warmup_steps=args.warmup_steps
605
+ )
606
+ if args.ddp:
607
+ scaler = GradScaler()
608
+ else:
609
+ scaler = ShardedGradScaler()
610
+ total_laion_token = 0
611
+ total_pile_token = 0
612
+ total_laion_sample = 0
613
+ total_step = 0
614
+
615
+ # check if a checkpoint exists for this run
616
+ if os.path.exists(f"{args.run_name}"):
617
+ checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
618
+ if len(checkpoint_list) == 0:
619
+ if args.rank == 0:
620
+ logging.info(f"Found no checkpoints for run {args.run_name}.")
621
+ else:
622
+ args.resume_from_checkpoint = sorted(
623
+ checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
624
+ )[-1]
625
+ if args.rank == 0:
626
+ logging.info(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
627
+ args.restart = False
628
+ if args.rank == 0:
629
+ logging.info("do not restart because an existed checkpoint is found")
630
+ if args.resume_from_checkpoint is not None:
631
+ if args.rank == 0:
632
+ logging.info(f"Loading checkpoint from {args.resume_from_checkpoint}")
633
+ checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
634
+ torch.distributed.barrier()
635
+ if args.ddp:
636
+ model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
637
+ # sharded_osd = checkpoint['optimizer_state_dict']
638
+ else:
639
+ with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
640
+ if args.reset_llm:
641
+ for key in checkpoint["model_state_dict"]:
642
+ if key.startswith("lang_encoder"):
643
+ if args.rank == 0:
644
+ logging.info(f"reset {key}")
645
+ llm_key = key.replace("lang_encoder.", "")
646
+ checkpoint["model_state_dict"][key] = llm_state_dict[llm_key]
647
+ model_state_dict = model.state_dict()
648
+ for key in checkpoint["model_state_dict"].keys():
649
+ if model_state_dict[key].shape != checkpoint["model_state_dict"][key].shape:
650
+ if args.rank == 0:
651
+ logging.info(f'{key}: shape mismatch! {model_state_dict[key].shape} vs {checkpoint["model_state_dict"][key].shape}')
652
+ checkpoint["model_state_dict"][key] = model_state_dict[key].clone()
653
+ del model_state_dict
654
+ model.load_state_dict(checkpoint["model_state_dict"], False)
655
+ # sharded_osd = FSDP.shard_full_optim_state_dict(checkpoint['optimizer_state_dict'], model, optim_input=optim_groups)
656
+ if not args.restart:
657
+ # optimizer.load_state_dict(sharded_osd)
658
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
659
+ # scaler.load_state_dict(checkpoint["scaler_state_dict"])
660
+ total_laion_token = checkpoint.get("total_laion_token", 0)
661
+ total_pile_token = checkpoint.get("total_pile_token", 0)
662
+ total_laion_sample = checkpoint.get("total_laion_sample", 0)
663
+ total_step = checkpoint.get("total_step", 0)
664
+ if args.rank == 0:
665
+ logging.info("load training statistics...")
666
+ else:
667
+ if args.rank == 0:
668
+ logging.info("restart training / finetuning. only load model weight...")
669
+ del checkpoint
670
+ if args.reset_llm:
671
+ del llm_state_dict
672
+ torch.cuda.empty_cache()
673
+ torch.distributed.barrier()
674
+
675
+ model.train()
676
+ if args.rank == 0:
677
+ if not os.path.exists(args.run_name):
678
+ os.makedirs(args.run_name)
679
+ writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog"))
680
+ else:
681
+ writer = None
682
+
683
+ laion_dataset.set_epoch(total_step)
684
+ laion_loader = laion_dataset.dataloader
685
+ if pile_dataset is not None:
686
+ pile_dataset.set_epoch(total_step)
687
+ pile_loader = pile_dataset.dataloader
688
+ else:
689
+ pile_loader = FakeDataloader()
690
+ train_one_epoch(
691
+ args=args,
692
+ model=model,
693
+ tokenizer=tokenizer,
694
+ optimizer=optimizer,
695
+ lr_scheduler=lr_scheduler,
696
+ laion_loader=laion_loader,
697
+ pile_loader=pile_loader,
698
+ device_id=device_id,
699
+ writer=writer,
700
+ scaler=scaler,
701
+ optim_groups=optim_groups,
702
+ total_laion_token=total_laion_token,
703
+ total_pile_token=total_pile_token,
704
+ total_laion_sample=total_laion_sample,
705
+ total_step=total_step,
706
+ )
707
+
708
+ if __name__ == "__main__":
709
+ main()
multimodal/build/lib/open_flamingo/train/train_utils.py ADDED
@@ -0,0 +1,387 @@
1
+ import time
2
+ from contextlib import suppress
3
+ import numpy as np
4
+
5
+ import torch
6
+ from tqdm import tqdm
7
+ import datetime
8
+ import os
9
+ import gc
10
+ from torch.distributed.fsdp import (
11
+ FullyShardedDataParallel as FSDP,
12
+ MixedPrecision,
13
+ BackwardPrefetch,
14
+ ShardingStrategy,
15
+ FullStateDictConfig,
16
+ StateDictType,
17
+ )
18
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
19
+ from torch.distributed.fsdp.wrap import (
20
+ transformer_auto_wrap_policy,
21
+ enable_wrap,
22
+ wrap,
23
+ )
24
+
25
+ from torch.utils.tensorboard import SummaryWriter
26
+ import logging
27
+ logging.basicConfig(
28
+ level=logging.INFO,
29
+ format='%(asctime)s %(message)s',
30
+ datefmt='%m/%d %I:%M:%S',
31
+ )
32
+
33
+ def get_cast_dtype(precision: str):
34
+ cast_dtype = None
35
+ if precision == "bf16":
36
+ cast_dtype = torch.bfloat16
37
+ elif precision == "fp16":
38
+ cast_dtype = torch.float16
39
+ return cast_dtype
40
+
41
+
42
+ def get_autocast(precision):
43
+ if precision == "amp_fp16":
44
+ return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
45
+ elif precision == "amp_bfloat16" or precision == "amp_bf16":
46
+ # amp_bfloat16 is more stable than amp float16 for clip training
47
+ return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
48
+ else:
49
+ return suppress
50
+
51
+
52
+ def get_sync(model, flag):
53
+ if flag:
54
+ return suppress
55
+ else:
56
+ return lambda: model.no_sync()
57
+
58
+
59
+ def train_one_epoch(
60
+ args,
61
+ model,
62
+ laion_loader,
63
+ pile_loader,
64
+ tokenizer,
65
+ optimizer,
66
+ lr_scheduler,
67
+ device_id,
68
+ writer: SummaryWriter,
69
+ optim_groups,
70
+ scaler,
71
+ total_laion_token: int,
72
+ total_pile_token: int,
73
+ total_laion_sample: int,
74
+ total_step: int,
75
+ ):
76
+ world_size = torch.distributed.get_world_size()
77
+ autocast = get_autocast(args.precision)
78
+ cast_dtype = get_cast_dtype(args.precision)
79
+
80
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
81
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
82
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
83
+ if args.add_box:
84
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
85
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
86
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
87
+ if args.use_format_v2:
88
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
89
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
90
+ if args.rank == 0:
91
+ logging.info(f"train from: {total_step} step")
92
+ model.train()
93
+ # loop through dataloader
94
+ last_logging_step = total_step
95
+ last_save_step = total_step
96
+ for num_steps, (batch_laion, batch_pile) in tqdm(
97
+ enumerate(zip(laion_loader, pile_loader)),
98
+ disable=args.rank != 0 or "SLURM_PROCID" in os.environ,
99
+ total=args.num_steps * args.gradient_accumulation_steps,
100
+ initial=total_step * args.gradient_accumulation_steps,
101
+ ):
102
+ #### LAION FORWARD PASS ####
103
+ images = (
104
+ batch_laion[0]
105
+ .to(device_id, dtype=cast_dtype, non_blocking=True)
106
+ .unsqueeze(1)
107
+ .unsqueeze(1)
108
+ )
109
+ image_nums = batch_laion[1]
110
+ image_start_index_list = batch_laion[2]
111
+
112
+ # TODO: OPT model: input_ids does not start with </s> while input_ids2 does?
113
+ input_ids = batch_laion[3].to(device_id, non_blocking=True).long()
114
+ attention_mask = batch_laion[4].to(device_id, dtype=cast_dtype, non_blocking=True)
115
+ added_bbox_list = [x.to(device_id) for x in batch_laion[5]] # list object
116
+ total_laion_token += int(attention_mask.sum().long()) * world_size
117
+ total_laion_sample += sum(image_nums) * world_size
118
+
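+ # build LM labels from input_ids; positions set to -100 below (grounding/special tokens, padding, media markers, the first token) are ignored by the loss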
119
+ labels = input_ids.clone()
120
+ if args.add_box:
121
+ labels[input_ids == visual_token_id] = -100
122
+ labels[input_ids == box_token_id] = -100
123
+ labels[input_ids == endofattr_token_id] = -100
124
+ if args.use_format_v2:
125
+ labels[input_ids == previsual_token_id] = -100
126
+ labels[input_ids == prebox_token_id] = -100
127
+ labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
128
+ labels[torch.roll(input_ids == box_token_id, 1)] = -100
129
+ labels[:, 0] = -100
130
+ labels[input_ids == tokenizer.pad_token_id] = -100
131
+ labels[input_ids == media_token_id] = -100
132
+ labels[input_ids == endofmedia_token_id] = -100
133
+ labels = labels.to(device_id)
134
+ current_laion_num = input_ids.shape[0]
135
+
136
+ #### PILE FORWARD PASS ####
137
+ if batch_pile is not None and batch_pile[0] is not None and batch_pile[1] is not None:
138
+ input_ids2 = batch_pile[0].to(device_id, non_blocking=True).long()
139
+ attention_mask2 = batch_pile[1].to(device_id, dtype=cast_dtype, non_blocking=True)
140
+ input_length = input_ids.shape[-1]
141
+
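+ # right-pad the text-only (pile) batch with pad tokens and zero attention so it matches the image-text sequence length and can be concatenated below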
142
+ input_ids2 = torch.cat([input_ids2, torch.ones((input_ids2.shape[0], input_length - input_ids2.shape[1]), device=input_ids2.device, dtype=input_ids2.dtype) * tokenizer.pad_token_id], dim=-1)
143
+ attention_mask2 = torch.cat([attention_mask2, torch.zeros((attention_mask2.shape[0], input_length - attention_mask2.shape[1]), device=attention_mask2.device, dtype=attention_mask2.dtype)], dim=-1)
144
+
145
+ labels2 = input_ids2.clone()
146
+ labels2[labels2 == tokenizer.pad_token_id] = -100
147
+ labels2[:, 0] = -100
148
+ labels2 = labels2.to(device_id)
149
+
150
+ if (num_steps != 0 and num_steps % args.pile_freq == 0) or args.pile_freq == 1:
151
+ image_nums = image_nums + [0] * len(input_ids2)
152
+ image_start_index_list = image_start_index_list + [[]] * len(input_ids2)
153
+ input_ids = torch.cat([input_ids, input_ids2], dim=0)
154
+ attention_mask = torch.cat([attention_mask, attention_mask2], dim=0)
155
+ labels = torch.cat([labels, labels2], dim=0)
156
+ total_pile_token += int(attention_mask2.sum().long()) * world_size
157
+ else:
158
+ del input_ids2
159
+ del attention_mask2
160
+ del labels2
161
+
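+ # instruction data: mask the prompt up to and including the " Answer" marker (and the token right after it) so loss is only computed on the answer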
162
+ if args.instruct:
163
+ answer_token_id = tokenizer(" Answer").input_ids[0]
164
+ answer_token_loc = (input_ids == answer_token_id).nonzero()
165
+ for batch_idx, idx in answer_token_loc:
166
+ labels[batch_idx][:idx+2] = -100
167
+
168
+ if args.relation and not args.instruct:
169
+ relations = batch_laion[6]
170
+ else:
171
+ relations = None
172
+ if len(added_bbox_list) == 0:
173
+ added_bbox_list = None
174
+ update_flag = (num_steps != 0 and num_steps % args.gradient_accumulation_steps == 0) or args.gradient_accumulation_steps == 1
175
+ # do_sync = get_sync(model, update_flag)
176
+ with autocast():
177
+ # modify:
178
+ # /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/codegen/modeling_codegen.py
179
+ # /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py
180
+ # CrossEntropyLoss(reduction="none")
181
+ outputs = model(
182
+ vision_x=images,
183
+ lang_x=input_ids,
184
+ attention_mask=attention_mask,
185
+ labels=labels,
186
+ image_nums=image_nums,
187
+ image_start_index_list=image_start_index_list,
188
+ added_bbox_list=added_bbox_list,
189
+ add_box=args.add_box,
190
+ relations=relations,
191
+ )
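+ # outputs.loss is per-token here (the LM loss is expected to use reduction="none", see the note above); reshape to (batch, seq_len) and average the non-zero entries to get a per-sample loss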
192
+ loss_total = outputs.loss.reshape(labels.shape[0], -1)
193
+ loss_sample = loss_total.sum(-1) / (loss_total != 0).sum(-1)
194
+ loss_sample_for_laion = loss_sample[:current_laion_num]
195
+ nan_mask = torch.isnan(loss_sample_for_laion)
196
+ if nan_mask.sum() > 0:
197
+ logging.warning(f"caption NaN: {nan_mask}")
198
+ if nan_mask.sum() == len(loss_sample_for_laion) or not model.valid:
199
+ logging.info("WARNING: skip this caption loss due to some error")
200
+ loss_laion = torch.tensor(0.0).cuda()
201
+ else:
202
+ loss_laion = loss_sample_for_laion[~nan_mask].mean()
203
+ loss_caption = loss_laion
204
+ divided_loss_laion = loss_laion / args.gradient_accumulation_steps
205
+ if current_laion_num != loss_sample.shape[0]:
206
+ loss_pile = loss_sample[current_laion_num:].mean()
207
+ else:
208
+ loss_pile = torch.tensor(0.0).cuda()
209
+ divided_loss_pile = loss_pile / args.gradient_accumulation_steps
210
+
211
+ if "detection_losses" in outputs:
212
+ loss_det = outputs["detection_losses"]["loss"]
213
+ loss_iou = outputs["detection_losses"]["loss_iou"]
214
+ loss_obj = outputs["detection_losses"]["loss_obj"]
215
+ loss_cls = outputs["detection_losses"]["loss_cls"]
216
+ else:
217
+ loss_det = torch.tensor(0.0).cuda()
218
+ loss_iou = torch.tensor(0.0).cuda()
219
+ loss_obj = torch.tensor(0.0).cuda()
220
+ loss_cls = torch.tensor(0.0).cuda()
221
+
222
+ if "loss_dict" in outputs:
223
+ visual_loss_iou = outputs["loss_dict"][0]["loss_iou"]
224
+ previsual_loss_iou = outputs["loss_dict"][1]["loss_iou"]
225
+ visual_loss_obj = outputs["loss_dict"][0]["loss_obj"]
226
+ previsual_loss_obj = outputs["loss_dict"][1]["loss_obj"]
227
+ else:
228
+ visual_loss_iou = torch.tensor(0.0).cuda()
229
+ previsual_loss_iou = torch.tensor(0.0).cuda()
230
+ visual_loss_obj = torch.tensor(0.0).cuda()
231
+ previsual_loss_obj = torch.tensor(0.0).cuda()
232
+
233
+ divided_loss_det = loss_det / args.gradient_accumulation_steps
234
+ loss_rel = outputs.get("rel_loss", torch.tensor(0.0).cuda())
235
+ divided_loss_rel = loss_rel / args.gradient_accumulation_steps
236
+ loss = (
237
+ divided_loss_laion * args.loss_multiplier_laion +
238
+ divided_loss_pile * args.loss_multiplier_pile +
239
+ divided_loss_det * args.loss_multiplier_det +
240
+ divided_loss_rel * args.loss_multiplier_rel
241
+ )
242
+
243
+ scaler.scale(loss).backward()
244
+
245
+ # for logging only
246
+ loss = (
247
+ loss_laion * args.loss_multiplier_laion
248
+ + loss_pile * args.loss_multiplier_pile
249
+ + loss_det * args.loss_multiplier_det
250
+ + loss_rel * args.loss_multiplier_rel
251
+ ).detach()
252
+
253
+ # step optimizer and log
254
+ if update_flag:
255
+ #### MASK GRADIENTS FOR EMBEDDINGS ####
256
+ # Note (anas): Do not apply weight decay to embeddings as it will break this function.
257
+ # ! not an important point
258
+ # if args.ddp:
259
+ # def mask_embedding(m):
260
+ # if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
261
+ # zero_mask = torch.zeros_like(m.weight.grad)
262
+ # zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
263
+ # zero_mask[endofmedia_token_id] = torch.ones_like(zero_mask[endofmedia_token_id])
264
+ # m.weight.grad = m.weight.grad * zero_mask
265
+ # model.apply(mask_embedding)
266
+ total_step += 1
267
+ scaler.unscale_(optimizer)
268
+ if args.ddp:
269
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
270
+ else:
271
+ model.clip_grad_norm_(1.0)
272
+ scaler.step(optimizer)
273
+ scaler.update()
274
+ lr_scheduler.step()
275
+ optimizer.zero_grad()
276
+ # https://github.com/facebookresearch/fairscale/issues/627
277
+ model.zero_grad(set_to_none=True)
278
+
279
+ if args.rank == 0 and total_step % args.logging_steps == 0 and total_step != last_logging_step:
280
+ last_logging_step = total_step
281
+ global_step = total_step
282
+ lr = optimizer.param_groups[0]["lr"]
283
+ writer.add_scalar("lr", lr, global_step)
284
+ writer.add_scalar("scale", scaler.get_scale(), global_step)
285
+ writer.add_scalar("loss_groundcaption", loss_laion.item(), global_step)
286
+ writer.add_scalar("loss_laion", loss_caption.item(), global_step)
287
+ writer.add_scalar("loss_pile", loss_pile.item(), global_step)
288
+ writer.add_scalar("loss", loss.item(), global_step)
289
+ writer.add_scalar("loss_det", loss_det.item(), global_step)
290
+ writer.add_scalar("loss_iou", loss_iou.item(), global_step)
291
+ writer.add_scalar("loss_obj", loss_obj.item(), global_step)
292
+ writer.add_scalar("loss_cls", loss_cls.item(), global_step)
293
+ if loss_rel.item() != 0:
294
+ writer.add_scalar("loss_rel", loss_rel.item(), global_step)
295
+ if args.use_format_v2:
296
+ writer.add_scalar("loss_iou_visual", visual_loss_iou.item(), global_step)
297
+ writer.add_scalar("loss_obj_visual", visual_loss_obj.item(), global_step)
298
+ writer.add_scalar("loss_iou_previsual", previsual_loss_iou.item(), global_step)
299
+ writer.add_scalar("loss_obj_previsual", previsual_loss_obj.item(), global_step)
300
+
301
+ global_sample_num = total_laion_sample
302
+ writer.add_scalar("loss_groundcaption_vs_sample_num", loss_laion.item(), global_sample_num)
303
+ writer.add_scalar("loss_laion_vs_sample_num", loss_caption.item(), global_sample_num)
304
+ writer.add_scalar("loss_pile_vs_sample_num", loss_pile.item(), global_sample_num)
305
+ writer.add_scalar("loss_vs_sample_num", loss.item(), global_sample_num)
306
+ writer.add_scalar("loss_det_vs_sample_num", loss_det.item(), global_sample_num)
307
+ writer.add_scalar("loss_iou_vs_sample_num", loss_iou.item(), global_sample_num)
308
+ writer.add_scalar("loss_obj_vs_sample_num", loss_obj.item(), global_sample_num)
309
+ if loss_rel.item() != 0:
310
+ writer.add_scalar("loss_rel_vs_sample_num", loss_rel.item(), global_sample_num)
311
+ writer.add_scalar("lr_vs_sample_num", optimizer.param_groups[0]["lr"], global_sample_num)
312
+
313
+ writer.add_scalar("loss_groundcaption_vs_token", loss_laion.item(), total_laion_token)
314
+ writer.add_scalar("loss_laion_vs_token", loss_caption.item(), total_laion_token)
315
+ writer.add_scalar("loss_pile_vs_token", loss_pile.item(), total_pile_token)
316
+ writer.add_scalar("loss_det_vs_token", loss_det.item(), total_laion_token)
317
+ writer.add_scalar("loss_iou_vs_token", loss_iou.item(), total_laion_token)
318
+ writer.add_scalar("loss_obj_vs_token", loss_obj.item(), total_laion_token)
319
+ writer.add_scalar("loss_cls_vs_token", loss_cls.item(), total_laion_token)
320
+ if loss_rel.item() != 0:
321
+ writer.add_scalar("loss_rel_vs_token", loss_rel.item(), total_laion_token)
322
+
323
+ total_token = total_laion_token + total_pile_token
324
+ writer.add_scalar("sample_num", global_sample_num, global_step)
325
+ writer.add_scalar("total_laion_token", total_laion_token, global_step)
326
+ writer.add_scalar("total_pile_token", total_pile_token, global_step)
327
+ writer.add_scalar("total_token", total_token, global_step)
328
+ logging.info(
329
+ f"[{global_step}][{total_laion_sample}][{total_token}]. total: {loss.item():.3f} // laion: {loss_caption.item():.3f} // pile: {loss_pile.item():.3f} // iou: {loss_iou.item():.4f} // obj: {loss_obj.item():.4f} // previsual_obj: {previsual_loss_obj.item():.4f} // visual_obj: {visual_loss_obj.item():.4f} // previsual_iou: {previsual_loss_iou.item():.4f} // visual_iou: {visual_loss_iou.item():.4f} // lr: {lr:.2e} // scale: {scaler.get_scale()}"
330
+ )
331
+
332
+ if total_step % args.save_interval == 0 and total_step != last_save_step:
333
+ last_save_step = total_step
334
+ torch.distributed.barrier()
335
+ if args.ddp:
336
+ cpu_state = model.state_dict()
337
+ # if args.rank == 0:
338
+ # optimizer_state = optimizer.state_dict()
339
+ else:
340
+ save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
341
+ with FSDP.state_dict_type(
342
+ model, StateDictType.FULL_STATE_DICT, save_policy
343
+ ):
344
+ cpu_state = model.state_dict()
345
+ torch.distributed.barrier()
346
+ # https://pytorch.org/docs/1.12/fsdp.html
347
+ # need to pass optim_groups as optim_input
348
+ # optimizer_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_groups)
349
+ if args.rank == 0:
350
+ checkpoint_dict = {
351
+ "model_state_dict": cpu_state,
352
+ # "optimizer_state_dict": optimizer_state,
353
+ "lr_scheduler_state_dict": lr_scheduler.state_dict(),
354
+ "scaler_state_dict": scaler.state_dict(),
355
+ "total_pile_token": total_pile_token,
356
+ "total_laion_token": total_laion_token,
357
+ "total_laion_sample": total_laion_sample,
358
+ "total_step": total_step,
359
+ }
360
+ logging.info(f"Saving checkpoint to {args.run_name}/checkpoint_{total_step}.pt")
361
+ torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{total_step}.pt")
362
+ del checkpoint_dict
363
+ if args.delete_previous_checkpoint and total_step-args.save_interval > 0 and (total_step-args.save_interval) % args.skip_delete_pattern != 0:
364
+ try:
365
+ os.remove(f"{args.run_name}/checkpoint_{total_step-args.save_interval}.pt")
366
+ except OSError:
367
+ pass
368
+ torch.distributed.barrier()
369
+
370
+
371
+ class AverageMeter(object):
372
+ """Computes and stores the average and current value"""
373
+
374
+ def __init__(self):
375
+ self.reset()
376
+
377
+ def reset(self):
378
+ self.val = 0
379
+ self.avg = 0
380
+ self.sum = 0
381
+ self.count = 0
382
+
383
+ def update(self, val, n=1):
384
+ self.val = val
385
+ self.sum += val * n
386
+ self.count += n
387
+ self.avg = self.sum / self.count
multimodal/open_flamingo.egg-info/PKG-INFO ADDED
@@ -0,0 +1,247 @@
1
+ Metadata-Version: 2.1
2
+ Name: open-flamingo
3
+ Version: 0.0.2
4
+ Summary: An open-source framework for training large multimodal models
5
+ License: MIT
6
+ Keywords: machine learning
7
+ Classifier: Development Status :: 4 - Beta
8
+ Classifier: Intended Audience :: Developers
9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Programming Language :: Python :: 3.9
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+
15
+ # 🦩 OpenFlamingo
16
+
17
+ [![PyPI version](https://badge.fury.io/py/open_flamingo.svg)](https://badge.fury.io/py/open_flamingo)
18
+
19
+ [Blog post](https://laion.ai/blog/open-flamingo/) | Paper (coming soon)
20
+
21
+ Welcome to our open source version of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) model! In this repository, we provide a PyTorch implementation for training and evaluating OpenFlamingo models. We also provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) trained on a new Multimodal C4 dataset (coming soon). Please refer to our blog post for more details.
22
+
23
+ This repo is still under development, and we hope to release better performing and larger OpenFlamingo models soon. If you have any questions, please feel free to open an issue. We also welcome contributions!
24
+
25
+ # Table of Contents
26
+ - [Installation](#installation)
27
+ - [Approach](#approach)
28
+ * [Model architecture](#model-architecture)
29
+ - [Usage](#usage)
30
+ * [Initializing an OpenFlamingo model](#initializing-an-openflamingo-model)
31
+ * [Generating text](#generating-text)
32
+ - [Training](#training)
33
+ * [Dataset](#dataset)
34
+ - [Evaluation](#evaluation)
35
+ - [Future plans](#future-plans)
36
+ - [Team](#team)
37
+ - [Acknowledgments](#acknowledgments)
38
+ - [Citing](#citing)
39
+
40
+ # Installation
41
+
42
+ To install the package in an existing environment, run
43
+ ```
44
+ pip install open-flamingo
45
+ ```
46
+
47
+ or to create a conda environment for running OpenFlamingo, run
48
+ ```
49
+ conda env create -f environment.yml
50
+ ```
51
+
52
+ # Usage
53
+ We provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) using a CLIP ViT-Large vision encoder and a LLaMA-7B language model. In general, we support any [CLIP vision encoder](https://huggingface.co/models?search=clip). For the language model, we support [LLaMA](https://huggingface.co/models?search=llama), [OPT](https://huggingface.co/models?search=opt), [GPT-Neo](https://huggingface.co/models?search=gpt-neo), [GPT-J](https://huggingface.co/models?search=gptj), and [Pythia](https://huggingface.co/models?search=pythia) models.
54
+
55
+ #### NOTE: To use LLaMA models, you will need to install the latest version of transformers via
56
+ ```
57
+ pip install git+https://github.com/huggingface/transformers
58
+ ```
59
+ Use this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) for converting LLaMA weights to HuggingFace format.
60
+
61
+ ## Initializing an OpenFlamingo model
62
+ ``` python
63
+ from open_flamingo import create_model_and_transforms
64
+
65
+ model, image_processor, tokenizer = create_model_and_transforms(
66
+ clip_vision_encoder_path="ViT-L-14",
67
+ clip_vision_encoder_pretrained="openai",
68
+ lang_encoder_path="<path to llama weights in HuggingFace format>",
69
+ tokenizer_path="<path to llama tokenizer in HuggingFace format>",
70
+ cross_attn_every_n_layers=4
71
+ )
72
+
73
+ # grab model checkpoint from huggingface hub
74
+ from huggingface_hub import hf_hub_download
75
+ import torch
76
+
77
+ checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B", "checkpoint.pt")
78
+ model.load_state_dict(torch.load(checkpoint_path), strict=False)
79
+ ```
80
+
81
+ ## Generating text
82
+ Here is an example of generating text conditioned on interleaved images/text; in this case we do few-shot image captioning.
83
+
84
+ ``` python
85
+ from PIL import Image
86
+ import requests
87
+
88
+ """
89
+ Step 1: Load images
90
+ """
91
+ demo_image_one = Image.open(
92
+ requests.get(
93
+ "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
94
+ ).raw
95
+ )
96
+
97
+ demo_image_two = Image.open(
98
+ requests.get(
99
+ "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
100
+ stream=True
101
+ ).raw
102
+ )
103
+
104
+ query_image = Image.open(
105
+ requests.get(
106
+ "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
107
+ stream=True
108
+ ).raw
109
+ )
110
+
111
+
112
+ """
113
+ Step 2: Preprocessing images
114
+ Details: For OpenFlamingo, we expect the image to be a torch tensor of shape
115
+ batch_size x num_media x num_frames x channels x height x width.
116
+ In this case batch_size = 1, num_media = 3, num_frames = 1
117
+ (this will always be one except for video, which we don't support yet),
118
+ channels = 3, height = 224, width = 224.
119
+ """
120
+ vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
121
+ vision_x = torch.cat(vision_x, dim=0)
122
+ vision_x = vision_x.unsqueeze(1).unsqueeze(0)
123
+
124
+ """
125
+ Step 3: Preprocessing text
126
+ Details: In the text we expect an <|#image#|> special token to indicate where an image is.
127
+ We also expect an <|endofchunk|> special token to indicate the end of the text
128
+ portion associated with an image.
129
+ """
130
+ tokenizer.padding_side = "left" # For generation padding tokens should be on the left
131
+ lang_x = tokenizer(
132
+ ["<|#image#|>An image of two cats.<|endofchunk|><|#image#|>An image of a bathroom sink.<|endofchunk|><|#image#|>An image of"],
133
+ return_tensors="pt",
134
+ )
135
+
136
+
137
+ """
138
+ Step 4: Generate text
139
+ """
140
+ generated_text = model.generate(
141
+ vision_x=vision_x,
142
+ lang_x=lang_x["input_ids"],
143
+ attention_mask=lang_x["attention_mask"],
144
+ max_new_tokens=20,
145
+ num_beams=3,
146
+ )
147
+
148
+ print("Generated text: ", tokenizer.decode(generated_text[0]))
149
+ ```
150
+
151
+ # Approach
152
+ OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. Multimodal C4) and can be used to generate text conditioned on interleaved images/text. For example, OpenFlamingo can be used to generate a caption for an image, or to generate a question given an image and a text passage. The benefit of this approach is that we are able to rapidly adapt to new tasks using in-context learning.
153
+
154
+ ## Model architecture
155
+ OpenFlamingo seeks to fuse a pretrained vision encoder and a language model using cross-attention layers. The model architecture is shown below.
156
+
157
+ ![OpenFlamingo architecture](docs/flamingo.png)
158
+ Credit: [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)
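+ 
+ For intuition, each fusion step lets the text tokens cross-attend to the vision tokens through a gated residual branch; with a zero-initialized tanh gate (the Flamingo-style choice), the pretrained language model is unchanged at initialization. The sketch below is illustrative only and is not the repository's implementation (see `open_flamingo/src/helpers.py` and `open_flamingo/src/flamingo_lm.py` for the real modules):
+ 
+ ``` python
+ import torch
+ import torch.nn as nn
+ 
+ class GatedCrossAttentionSketch(nn.Module):
+     """Illustrative stand-in for a gated cross-attention fusion layer."""
+     def __init__(self, dim: int, num_heads: int = 8):
+         super().__init__()
+         self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
+         self.gate = nn.Parameter(torch.zeros(1))  # tanh(0) = 0: no vision influence at init
+ 
+     def forward(self, text_tokens, vision_tokens):
+         # text tokens query the vision tokens (keys/values)
+         attended, _ = self.attn(text_tokens, vision_tokens, vision_tokens)
+         return text_tokens + torch.tanh(self.gate) * attended
+ ```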
159
+
160
+ # Training
161
+ To train a model, modify the following example command, which uses OPT 1.3B as an example LM:
162
+ ```
163
+ torchrun --nnodes=1 --nproc_per_node=4 train.py \
164
+ --run_name flamingo3B \
165
+ --lm_path facebook/opt-1.3b \
166
+ --tokenizer_path facebook/opt-1.3b \
167
+ --dataset_resampled \
168
+ --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
169
+ --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
170
+ --batch_size_mmc4 4 \
171
+ --batch_size_laion 8 \
172
+ --train_num_samples_mmc4 125000 \
173
+ --train_num_samples_laion 250000 \
174
+ --loss_multiplier_laion 0.2 \
175
+ --workers=6 \
176
+ --num_epochs 250 \
177
+ --lr_scheduler constant \
178
+ --warmup_steps 5000 \
179
+ --use_media_placement_augmentation \
180
+ --mmc4_textsim_threshold 30
181
+ ```
182
+
183
+ ## Dataset
184
+ We expect all our training datasets to be [WebDataset](https://github.com/webdataset/webdataset) shards.
185
+ We train our models on the [LAION 2B](https://huggingface.co/datasets/laion/laion2B-en) and Multimodal C4 (coming soon) datasets. The LAION 2B dataset is already in WebDataset format when downloaded with the [img2dataset tool](https://github.com/rom1504/img2dataset), and Multimodal C4 also comes packaged in WebDataset format.
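+ 
+ As a quick, illustrative sketch (this is not the project's training loader, which lives in `open_flamingo/train/data2.py`), an img2dataset-style shard can be iterated with the `webdataset` library; the shard path and the `jpg`/`txt` keys below are assumptions based on the img2dataset defaults:
+ 
+ ``` python
+ import webdataset as wds
+ 
+ shards = "/path/to/shards/shard-{0000..0999}.tar"  # hypothetical path
+ dataset = (
+     wds.WebDataset(shards)
+     .decode("pil")            # decode images to PIL.Image
+     .to_tuple("jpg", "txt")   # yield (image, caption) pairs
+ )
+ 
+ for image, caption in dataset:
+     print(image.size, caption[:60])
+     break
+ ```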
186
+
187
+
188
+ # Evaluation
189
+ We currently support running evaluations on [COCO](https://cocodataset.org/#home), [VQAv2](https://visualqa.org/index.html), [OKVQA](https://okvqa.allenai.org), [Flickr30k](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset), and [ImageNet](https://image-net.org/index.php). Note that these evaluations are currently run in validation mode (as specified in the Flamingo paper). We will add support for running evaluations in test mode in the future.
190
+
191
+ Before evaluating the model, you will need to install the coco evaluation package by running the following command:
192
+ ```
193
+ pip install pycocoevalcap
194
+ ```
195
+
196
+ To run evaluations on OKVQA you will first need to run the following Python snippet:
197
+ ```
198
+ import nltk
199
+ nltk.download('wordnet')
200
+ ```
201
+
202
+ To evaluate the model, run the script at `open_flamingo/scripts/run_eval.sh`
203
+
204
+ # Future plans
205
+ - [ ] Add support for video input
206
+ - [ ] Release better performing and larger OpenFlamingo models
207
+ - [ ] Expand our evaluation suite
208
+ - [ ] Add support for FSDP training
209
+
210
+ # Team
211
+
212
+ OpenFlamingo is developed by:
213
+
214
+ [Anas Awadalla](https://anas-awadalla.streamlit.app/), [Irena Gao](https://i-gao.github.io/), [Joshua Gardner](https://homes.cs.washington.edu/~jpgard/), [Jack Hessel](https://jmhessel.com/), [Yusuf Hanafy](https://www.linkedin.com/in/yusufhanafy/), [Wanrong Zhu](https://wanrong-zhu.com/), [Kalyani Marathe](https://sites.google.com/uw.edu/kalyanimarathe/home?authuser=0), [Yonatan Bitton](https://yonatanbitton.github.io/), [Samir Gadre](https://sagadre.github.io/), [Jenia Jitsev](https://scholar.google.de/citations?user=p1FuAMkAAAAJ&hl=en), [Simon Kornblith](https://simonster.com/), [Pang Wei Koh](https://koh.pw/), [Gabriel Ilharco](https://gabrielilharco.com/), [Mitchell Wortsman](https://mitchellnw.github.io/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/).
215
+
216
+ The team is primarily from the University of Washington, Stanford, AI2, UCSB, and Google.
217
+
218
+ # Acknowledgments
219
+ This code is based on Lucidrains' [flamingo implementation](https://github.com/lucidrains/flamingo-pytorch) and David Hansmair's [flamingo-mini repo](https://github.com/dhansmair/flamingo-mini). Thank you for making your code public! We also thank the [OpenCLIP](https://github.com/mlfoundations/open_clip) team as we use their data loading code and take inspiration from their library design.
220
+
221
+ We would also like to thank [Jean-Baptiste Alayrac](https://www.jbalayrac.com) and [Antoine Miech](https://antoine77340.github.io) for their advice, [Rohan Taori](https://www.rohantaori.com/), [Nicholas Schiefer](https://nicholasschiefer.com/), [Deep Ganguli](https://hai.stanford.edu/people/deep-ganguli), [Thomas Liao](https://thomasliao.com/), [Tatsunori Hashimoto](https://thashim.github.io/), and [Nicholas Carlini](https://nicholas.carlini.com/) for their help with assessing the safety risks of our release, and to [Stability AI](https://stability.ai) for providing us with compute resources to train these models.
222
+
223
+ # Citing
224
+ If you found this repository useful, please consider citing:
225
+
226
+ ```
227
+ @software{anas_awadalla_2023_7733589,
228
+ author = {Awadalla, Anas and Gao, Irena and Gardner, Joshua and Hessel, Jack and Hanafy, Yusuf and Zhu, Wanrong and Marathe, Kalyani and Bitton, Yonatan and Gadre, Samir and Jitsev, Jenia and Kornblith, Simon and Koh, Pang Wei and Ilharco, Gabriel and Wortsman, Mitchell and Schmidt, Ludwig},
229
+ title = {OpenFlamingo},
230
+ month = mar,
231
+ year = 2023,
232
+ publisher = {Zenodo},
233
+ version = {v0.1.1},
234
+ doi = {10.5281/zenodo.7733589},
235
+ url = {https://doi.org/10.5281/zenodo.7733589}
236
+ }
237
+ ```
238
+
239
+ ```
240
+ @article{Alayrac2022FlamingoAV,
241
+ title={Flamingo: a Visual Language Model for Few-Shot Learning},
242
+ author={Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan},
243
+ journal={ArXiv},
244
+ year={2022},
245
+ volume={abs/2204.14198}
246
+ }
247
+ ```
multimodal/open_flamingo.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,53 @@
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ open_flamingo/__init__.py
5
+ open_flamingo.egg-info/PKG-INFO
6
+ open_flamingo.egg-info/SOURCES.txt
7
+ open_flamingo.egg-info/dependency_links.txt
8
+ open_flamingo.egg-info/requires.txt
9
+ open_flamingo.egg-info/top_level.txt
10
+ open_flamingo/chat/__init__.py
11
+ open_flamingo/chat/conversation.py
12
+ open_flamingo/eval/__init__.py
13
+ open_flamingo/eval/classification.py
14
+ open_flamingo/eval/coco_metric.py
15
+ open_flamingo/eval/eval_datasets.py
16
+ open_flamingo/eval/evaluate.py
17
+ open_flamingo/eval/evaluate_debug.py
18
+ open_flamingo/eval/evaluate_find_showcase.py
19
+ open_flamingo/eval/evaluate_temp.py
20
+ open_flamingo/eval/imagenet_utils.py
21
+ open_flamingo/eval/ok_vqa_utils.py
22
+ open_flamingo/eval/vqa_metric.py
23
+ open_flamingo/eval/dataset_zoo/__init__.py
24
+ open_flamingo/eval/dataset_zoo/aro_datasets.py
25
+ open_flamingo/eval/dataset_zoo/constants.py
26
+ open_flamingo/eval/dataset_zoo/perturbations.py
27
+ open_flamingo/eval/dataset_zoo/retrieval.py
28
+ open_flamingo/eval/dataset_zoo/utils.py
29
+ open_flamingo/eval/task/__init__.py
30
+ open_flamingo/eval/task/caption.py
31
+ open_flamingo/eval/task/caption_chat.py
32
+ open_flamingo/eval/task/cola.py
33
+ open_flamingo/eval/task/crepe.py
34
+ open_flamingo/eval/task/gqa.py
35
+ open_flamingo/eval/task/mmbench.py
36
+ open_flamingo/eval/task/reg.py
37
+ open_flamingo/eval/task/utils.py
38
+ open_flamingo/eval/task/vl_checklist.py
39
+ open_flamingo/src/__init__.py
40
+ open_flamingo/src/attention.py
41
+ open_flamingo/src/factory.py
42
+ open_flamingo/src/flamingo.py
43
+ open_flamingo/src/flamingo_lm.py
44
+ open_flamingo/src/gcn.py
45
+ open_flamingo/src/helpers.py
46
+ open_flamingo/src/utils.py
47
+ open_flamingo/train/__init__.py
48
+ open_flamingo/train/data2.py
49
+ open_flamingo/train/distributed.py
50
+ open_flamingo/train/instruction_template.py
51
+ open_flamingo/train/train.py
52
+ open_flamingo/train/train_utils.py
53
+ tests/test_flamingo_model.py
multimodal/open_flamingo.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
multimodal/open_flamingo.egg-info/requires.txt ADDED
@@ -0,0 +1,17 @@
1
+ einops
2
+ einops-exts
3
+ transformers==4.31.0
4
+ torch==1.12.1
5
+ torchvision==0.13.1
6
+ pillow==9.3.0
7
+ more-itertools
8
+ datasets==2.9.0
9
+ braceexpand==0.1.7
10
+ webdataset
11
+ wandb==0.13.10
12
+ nltk
13
+ scipy
14
+ inflection
15
+ sentencepiece
16
+ open_clip_torch==2.20.0
17
+ opencv-python==4.7.0.68
multimodal/open_flamingo.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ open_flamingo