Spaces:
Runtime error
Runtime error
update cap
Browse files- app.py +3 -2
- multimodal/build/lib/open_flamingo/__init__.py +2 -0
- multimodal/build/lib/open_flamingo/chat/__init__.py +0 -0
- multimodal/build/lib/open_flamingo/chat/conversation.py +571 -0
- multimodal/build/lib/open_flamingo/eval/__init__.py +1 -0
- multimodal/build/lib/open_flamingo/eval/classification.py +147 -0
- multimodal/build/lib/open_flamingo/eval/coco_metric.py +23 -0
- multimodal/build/lib/open_flamingo/eval/dataset_zoo/__init__.py +33 -0
- multimodal/build/lib/open_flamingo/eval/dataset_zoo/aro_datasets.py +365 -0
- multimodal/build/lib/open_flamingo/eval/dataset_zoo/constants.py +3 -0
- multimodal/build/lib/open_flamingo/eval/dataset_zoo/perturbations.py +194 -0
- multimodal/build/lib/open_flamingo/eval/dataset_zoo/retrieval.py +266 -0
- multimodal/build/lib/open_flamingo/eval/dataset_zoo/utils.py +15 -0
- multimodal/build/lib/open_flamingo/eval/eval_datasets.py +101 -0
- multimodal/build/lib/open_flamingo/eval/evaluate.py +1435 -0
- multimodal/build/lib/open_flamingo/eval/evaluate_debug.py +1159 -0
- multimodal/build/lib/open_flamingo/eval/evaluate_find_showcase.py +1700 -0
- multimodal/build/lib/open_flamingo/eval/evaluate_temp.py +1838 -0
- multimodal/build/lib/open_flamingo/eval/imagenet_utils.py +1007 -0
- multimodal/build/lib/open_flamingo/eval/ok_vqa_utils.py +213 -0
- multimodal/build/lib/open_flamingo/eval/task/__init__.py +0 -0
- multimodal/build/lib/open_flamingo/eval/task/caption.py +419 -0
- multimodal/build/lib/open_flamingo/eval/task/caption_chat.py +417 -0
- multimodal/build/lib/open_flamingo/eval/task/cola.py +220 -0
- multimodal/build/lib/open_flamingo/eval/task/crepe.py +93 -0
- multimodal/build/lib/open_flamingo/eval/task/gqa.py +248 -0
- multimodal/build/lib/open_flamingo/eval/task/mmbench.py +84 -0
- multimodal/build/lib/open_flamingo/eval/task/reg.py +141 -0
- multimodal/build/lib/open_flamingo/eval/task/utils.py +287 -0
- multimodal/build/lib/open_flamingo/eval/task/vl_checklist.py +113 -0
- multimodal/build/lib/open_flamingo/eval/vqa_metric.py +594 -0
- multimodal/build/lib/open_flamingo/src/__init__.py +0 -0
- multimodal/build/lib/open_flamingo/src/attention.py +45 -0
- multimodal/build/lib/open_flamingo/src/factory.py +269 -0
- multimodal/build/lib/open_flamingo/src/flamingo.py +637 -0
- multimodal/build/lib/open_flamingo/src/flamingo_lm.py +173 -0
- multimodal/build/lib/open_flamingo/src/gcn.py +137 -0
- multimodal/build/lib/open_flamingo/src/helpers.py +263 -0
- multimodal/build/lib/open_flamingo/src/utils.py +31 -0
- multimodal/build/lib/open_flamingo/train/__init__.py +1 -0
- multimodal/build/lib/open_flamingo/train/data2.py +868 -0
- multimodal/build/lib/open_flamingo/train/distributed.py +128 -0
- multimodal/build/lib/open_flamingo/train/instruction_template.py +13 -0
- multimodal/build/lib/open_flamingo/train/train.py +709 -0
- multimodal/build/lib/open_flamingo/train/train_utils.py +387 -0
- multimodal/open_flamingo.egg-info/PKG-INFO +247 -0
- multimodal/open_flamingo.egg-info/SOURCES.txt +53 -0
- multimodal/open_flamingo.egg-info/dependency_links.txt +1 -0
- multimodal/open_flamingo.egg-info/requires.txt +17 -0
- multimodal/open_flamingo.egg-info/top_level.txt +1 -0
app.py
CHANGED
@@ -53,7 +53,8 @@ flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transfor
|
|
53 |
)
|
54 |
|
55 |
|
56 |
-
checkpoint_path =
|
|
|
57 |
checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
|
58 |
model_state_dict = {}
|
59 |
for key in checkpoint.keys():
|
@@ -326,7 +327,7 @@ with gr.Blocks() as demo:
|
|
326 |
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
|
327 |
queue=False)
|
328 |
|
329 |
-
demo.launch(enable_queue=True)
|
330 |
#
|
331 |
# with gr.Blocks() as demo:
|
332 |
# gr.Markdown(
|
|
|
53 |
)
|
54 |
|
55 |
|
56 |
+
checkpoint_path = "/home/aimos/huggingface/space/demo.pt"
|
57 |
+
# hf_hub_download("chendl/compositional_test", "pythiaS.pt")
|
58 |
checkpoint = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
|
59 |
model_state_dict = {}
|
60 |
for key in checkpoint.keys():
|
|
|
327 |
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
|
328 |
queue=False)
|
329 |
|
330 |
+
demo.launch(enable_queue=True,share=True)
|
331 |
#
|
332 |
# with gr.Blocks() as demo:
|
333 |
# gr.Markdown(
|
multimodal/build/lib/open_flamingo/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .src.flamingo import Flamingo
|
2 |
+
from .src.factory import create_model_and_transforms
|
multimodal/build/lib/open_flamingo/chat/__init__.py
ADDED
File without changes
|
multimodal/build/lib/open_flamingo/chat/conversation.py
ADDED
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import time
|
3 |
+
import re
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import transformers
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
|
10 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
11 |
+
|
12 |
+
import dataclasses
|
13 |
+
from enum import auto, Enum
|
14 |
+
from typing import List, Tuple, Any
|
15 |
+
|
16 |
+
import string
|
17 |
+
import cv2
|
18 |
+
import gradio as gr
|
19 |
+
|
20 |
+
from huggingface_hub import hf_hub_download, login
|
21 |
+
|
22 |
+
from open_flamingo.src.factory import create_model_and_transforms
|
23 |
+
from open_flamingo.eval.task.caption_chat import captioner
|
24 |
+
|
25 |
+
class SeparatorStyle(Enum):
|
26 |
+
"""Different separator style."""
|
27 |
+
SINGLE = auto()
|
28 |
+
TWO = auto()
|
29 |
+
|
30 |
+
|
31 |
+
@dataclasses.dataclass
|
32 |
+
class Conversation:
|
33 |
+
"""A class that keeps all conversation history."""
|
34 |
+
system: str
|
35 |
+
roles: List[str]
|
36 |
+
messages: List[List[str]]
|
37 |
+
offset: int
|
38 |
+
# system_img: List[Image.Image] = []
|
39 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
40 |
+
sep: str = "###"
|
41 |
+
sep2: str = None
|
42 |
+
|
43 |
+
skip_next: bool = False
|
44 |
+
conv_id: Any = None
|
45 |
+
|
46 |
+
def get_prompt(self):
|
47 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
48 |
+
ret = self.system + self.sep
|
49 |
+
for role, message in self.messages:
|
50 |
+
if message:
|
51 |
+
ret += role + ": " + message + self.sep
|
52 |
+
else:
|
53 |
+
ret += role + ":"
|
54 |
+
return ret
|
55 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
56 |
+
seps = [self.sep, self.sep2]
|
57 |
+
ret = self.system + seps[0]
|
58 |
+
for i, (role, message) in enumerate(self.messages):
|
59 |
+
if message:
|
60 |
+
ret += role + ": " + message + seps[i % 2]
|
61 |
+
else:
|
62 |
+
ret += role + ":"
|
63 |
+
return ret
|
64 |
+
else:
|
65 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
66 |
+
|
67 |
+
def append_message(self, role, message):
|
68 |
+
self.messages.append([role, message])
|
69 |
+
|
70 |
+
def to_gradio_chatbot(self):
|
71 |
+
ret = []
|
72 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
73 |
+
if i % 2 == 0:
|
74 |
+
ret.append([msg, None])
|
75 |
+
else:
|
76 |
+
ret[-1][-1] = msg
|
77 |
+
return ret
|
78 |
+
|
79 |
+
def copy(self):
|
80 |
+
return Conversation(
|
81 |
+
system=self.system,
|
82 |
+
# system_img=self.system_img,
|
83 |
+
roles=self.roles,
|
84 |
+
messages=[[x, y] for x, y in self.messages],
|
85 |
+
offset=self.offset,
|
86 |
+
sep_style=self.sep_style,
|
87 |
+
sep=self.sep,
|
88 |
+
sep2=self.sep2,
|
89 |
+
conv_id=self.conv_id)
|
90 |
+
|
91 |
+
def dict(self):
|
92 |
+
return {
|
93 |
+
"system": self.system,
|
94 |
+
# "system_img": self.system_img,
|
95 |
+
"roles": self.roles,
|
96 |
+
"messages": self.messages,
|
97 |
+
"offset": self.offset,
|
98 |
+
"sep": self.sep,
|
99 |
+
"sep2": self.sep2,
|
100 |
+
"conv_id": self.conv_id,
|
101 |
+
}
|
102 |
+
|
103 |
+
|
104 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
105 |
+
|
106 |
+
def __init__(self, stops=[], encounters=1):
|
107 |
+
super().__init__()
|
108 |
+
self.stops = stops
|
109 |
+
|
110 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
111 |
+
for stop in self.stops:
|
112 |
+
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
113 |
+
return True
|
114 |
+
|
115 |
+
return False
|
116 |
+
|
117 |
+
|
118 |
+
CONV_VISION = Conversation(
|
119 |
+
system="Give the following image: <Img>ImageContent</Img>. "
|
120 |
+
"You will be able to see the image once I provide it to you. Please answer my questions.",
|
121 |
+
roles=("Human", "Assistant"),
|
122 |
+
messages=[],
|
123 |
+
offset=2,
|
124 |
+
sep_style=SeparatorStyle.SINGLE,
|
125 |
+
sep="###",
|
126 |
+
)
|
127 |
+
|
128 |
+
def get_outputs(
|
129 |
+
model,
|
130 |
+
batch_images,
|
131 |
+
attention_mask,
|
132 |
+
max_generation_length,
|
133 |
+
min_generation_length,
|
134 |
+
num_beams,
|
135 |
+
length_penalty,
|
136 |
+
input_ids,
|
137 |
+
image_start_index_list=None,
|
138 |
+
image_nums=None,
|
139 |
+
bad_words_ids=None,
|
140 |
+
):
|
141 |
+
# and torch.cuda.amp.autocast(dtype=torch.float16)
|
142 |
+
with torch.inference_mode():
|
143 |
+
outputs = model(
|
144 |
+
vision_x=batch_images,
|
145 |
+
lang_x=input_ids,
|
146 |
+
attention_mask=attention_mask,
|
147 |
+
labels=None,
|
148 |
+
image_nums=image_nums,
|
149 |
+
image_start_index_list=image_start_index_list,
|
150 |
+
added_bbox_list=None,
|
151 |
+
add_box=False,
|
152 |
+
)
|
153 |
+
# outputs = model.generate(
|
154 |
+
# batch_images,
|
155 |
+
# input_ids,
|
156 |
+
# attention_mask=attention_mask,
|
157 |
+
# max_new_tokens=max_generation_length,
|
158 |
+
# min_length=min_generation_length,
|
159 |
+
# num_beams=num_beams,
|
160 |
+
# length_penalty=length_penalty,
|
161 |
+
# image_start_index_list=image_start_index_list,
|
162 |
+
# image_nums=image_nums,
|
163 |
+
# bad_words_ids=bad_words_ids,
|
164 |
+
# )
|
165 |
+
|
166 |
+
return outputs
|
167 |
+
|
168 |
+
def generate(
|
169 |
+
idx,
|
170 |
+
image,
|
171 |
+
text,
|
172 |
+
image_processor,
|
173 |
+
tokenizer,
|
174 |
+
flamingo,
|
175 |
+
vis_embed_size=256,
|
176 |
+
rank=0,
|
177 |
+
world_size=1,
|
178 |
+
):
|
179 |
+
if image is None:
|
180 |
+
raise gr.Error("Please upload an image.")
|
181 |
+
flamingo.eval()
|
182 |
+
loc_token_ids = []
|
183 |
+
for i in range(1000):
|
184 |
+
loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
185 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
186 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
187 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
188 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
189 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
190 |
+
|
191 |
+
image_ori = image
|
192 |
+
image = image.convert("RGB")
|
193 |
+
width = image.width
|
194 |
+
height = image.height
|
195 |
+
image = image.resize((224, 224))
|
196 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
197 |
+
if idx == 1:
|
198 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
|
199 |
+
bad_words_ids = None
|
200 |
+
max_generation_length = 5
|
201 |
+
else:
|
202 |
+
prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
|
203 |
+
bad_words_ids = loc_word_ids
|
204 |
+
max_generation_length = 300
|
205 |
+
encodings = tokenizer(
|
206 |
+
prompt,
|
207 |
+
padding="longest",
|
208 |
+
truncation=True,
|
209 |
+
return_tensors="pt",
|
210 |
+
max_length=2000,
|
211 |
+
)
|
212 |
+
input_ids = encodings["input_ids"]
|
213 |
+
attention_mask = encodings["attention_mask"]
|
214 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
215 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
216 |
+
image_nums = [1] * len(input_ids)
|
217 |
+
outputs = get_outputs(
|
218 |
+
model=flamingo,
|
219 |
+
batch_images=batch_images,
|
220 |
+
attention_mask=attention_mask,
|
221 |
+
max_generation_length=max_generation_length,
|
222 |
+
min_generation_length=4,
|
223 |
+
num_beams=1,
|
224 |
+
length_penalty=1.0,
|
225 |
+
input_ids=input_ids,
|
226 |
+
bad_words_ids=bad_words_ids,
|
227 |
+
image_start_index_list=image_start_index_list,
|
228 |
+
image_nums=image_nums,
|
229 |
+
)
|
230 |
+
|
231 |
+
boxes = outputs["boxes"]
|
232 |
+
scores = outputs["scores"]
|
233 |
+
if len(scores) > 0:
|
234 |
+
box = boxes[scores.argmax()]/224
|
235 |
+
print(f"{box}")
|
236 |
+
|
237 |
+
|
238 |
+
if len(boxes)>0:
|
239 |
+
open_cv_image = np.array(image_ori)
|
240 |
+
# Convert RGB to BGR
|
241 |
+
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
242 |
+
box = box*[width,height,width,height]
|
243 |
+
# for box in boxes:
|
244 |
+
open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
|
245 |
+
out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
|
246 |
+
return f"Output:{box}", out_image
|
247 |
+
else:
|
248 |
+
gen_text = tokenizer.batch_decode(outputs)
|
249 |
+
return (f"{gen_text}")
|
250 |
+
|
251 |
+
def preprocess_conv(data):
|
252 |
+
conversation = ""
|
253 |
+
BEGIN_SIGNAL = "### "
|
254 |
+
END_SIGNAL = "\n"
|
255 |
+
for idx, d in enumerate(data):
|
256 |
+
from_str = d["from"]
|
257 |
+
if from_str.lower() == "human":
|
258 |
+
from_str = "Human"
|
259 |
+
elif from_str.lower() == "gpt":
|
260 |
+
from_str = "Assistant"
|
261 |
+
else:
|
262 |
+
from_str = 'unknown'
|
263 |
+
conversation += (BEGIN_SIGNAL + from_str + ": " + d["value"] + END_SIGNAL)
|
264 |
+
return conversation
|
265 |
+
|
266 |
+
def preprocess_image(sample, image_processor):
|
267 |
+
image = image_processor(sample)
|
268 |
+
if isinstance(image, transformers.image_processing_utils.BatchFeature):
|
269 |
+
image = torch.tensor(image["pixel_values"][0])
|
270 |
+
return image
|
271 |
+
|
272 |
+
class Chat:
|
273 |
+
def __init__(self, model, vis_processor, tokenizer, vis_embed_size ):
|
274 |
+
self.model = model
|
275 |
+
self.vis_processor = vis_processor
|
276 |
+
self.tokenizer = tokenizer
|
277 |
+
self.vis_embed_size = vis_embed_size
|
278 |
+
self.conv = []
|
279 |
+
# stop_words_ids = [torch.tensor([835]).to(self.device),
|
280 |
+
# torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
|
281 |
+
# self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
282 |
+
|
283 |
+
def ask(self, text, conv,radio):
|
284 |
+
if radio in ["Cap"]:
|
285 |
+
conv.append({
|
286 |
+
"from": "human",
|
287 |
+
"value": "",
|
288 |
+
})
|
289 |
+
elif radio in ["VQA"]:
|
290 |
+
conv.append({
|
291 |
+
"from": "human",
|
292 |
+
"value": f"Answer the question using a single word or phrase. {text}",
|
293 |
+
})
|
294 |
+
elif radio in ["REC"]:
|
295 |
+
conv.append({
|
296 |
+
"from": "human",
|
297 |
+
"value": f"Please provide the bounding box coordinate of the region this sentence describes: {text}.",
|
298 |
+
})
|
299 |
+
else:
|
300 |
+
conv.append({
|
301 |
+
"from": "human",
|
302 |
+
"value": text,
|
303 |
+
})
|
304 |
+
# if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
|
305 |
+
# and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
|
306 |
+
# conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
|
307 |
+
# else:
|
308 |
+
# conv.append_message(conv.roles[0], text)
|
309 |
+
|
310 |
+
def answer(self, conv, img_list, radio, text_input, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
|
311 |
+
repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
|
312 |
+
# conv.append_message(conv.roles[1], None)
|
313 |
+
# embs = self.get_context_emb(conv, img_list)
|
314 |
+
#
|
315 |
+
# # current_max_len = embs.shape[1] + max_new_tokens + 100
|
316 |
+
# # begin_idx = max(0, current_max_len - max_length)
|
317 |
+
# # embs = embs[:, begin_idx:]
|
318 |
+
# outputs = self.model.llama_model.generate(
|
319 |
+
# inputs_embeds=embs,
|
320 |
+
# max_new_tokens=max_new_tokens,
|
321 |
+
# stopping_criteria=self.stopping_criteria,
|
322 |
+
# num_beams=num_beams,
|
323 |
+
# min_length=min_length,
|
324 |
+
# top_p=top_p,
|
325 |
+
# repetition_penalty=repetition_penalty,
|
326 |
+
# length_penalty=length_penalty,
|
327 |
+
# temperature=temperature,
|
328 |
+
# )
|
329 |
+
# output_token = outputs[0]
|
330 |
+
# if output_token[0] == 0:
|
331 |
+
# output_token = output_token[1:]
|
332 |
+
# output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
|
333 |
+
# output_text = output_text.split('###')[0] # remove the stop sign '###'
|
334 |
+
# output_text = output_text.split('Assistant:')[-1].strip()
|
335 |
+
# conv.messages[-1][1] = output_text
|
336 |
+
visual_token = "<|#visual#|>"
|
337 |
+
previsual_token = "<|#previsual#|>"
|
338 |
+
box_token = "<|#box#|>"
|
339 |
+
prebox_token = "<|#prebox#|>"
|
340 |
+
end_token = "<|#endofobject#|>"
|
341 |
+
object_token = "<|#object#|>"
|
342 |
+
end_of_attr_token = "<|#endofattr#|>"
|
343 |
+
preend_of_attr_token = "<|#preendofattr#|>"
|
344 |
+
media_token_id = self.tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
345 |
+
box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
346 |
+
endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
347 |
+
endofattr_token_id = self.tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
348 |
+
endofmedia_token_id = self.tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
349 |
+
visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
350 |
+
previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
351 |
+
prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
352 |
+
size = 224
|
353 |
+
self.model.eval()
|
354 |
+
# "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
|
355 |
+
# image_path = input("Please enter the image path: ")
|
356 |
+
image = img_list[0].convert("RGB")
|
357 |
+
image_ori = image
|
358 |
+
image = image.resize((size, size))
|
359 |
+
print(f"image size: {image.size}")
|
360 |
+
batch_images = preprocess_image(image, self.vis_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
361 |
+
|
362 |
+
# conversation = []
|
363 |
+
human_sentence = None
|
364 |
+
if radio in ["Cap","VQA"]:
|
365 |
+
conv.append({
|
366 |
+
"from": "gpt",
|
367 |
+
"value": "",
|
368 |
+
})
|
369 |
+
elif radio in ["REC"]:
|
370 |
+
conv.append(
|
371 |
+
{
|
372 |
+
"from": "gpt",
|
373 |
+
"value": object_token + text_input + end_token + visual_token,
|
374 |
+
}
|
375 |
+
)
|
376 |
+
else:
|
377 |
+
conv.append({
|
378 |
+
"from": "gpt",
|
379 |
+
"value": "",
|
380 |
+
})
|
381 |
+
# while True:
|
382 |
+
# human_sentence = input("### Human: ")
|
383 |
+
# if human_sentence == "#end#":
|
384 |
+
# break
|
385 |
+
# conversation.append({
|
386 |
+
# "from": "human",
|
387 |
+
# "value": human_sentence,
|
388 |
+
# })
|
389 |
+
# conversation.append({
|
390 |
+
# "from": "gpt",
|
391 |
+
# "value": "",
|
392 |
+
# })
|
393 |
+
text = preprocess_conv(conv).strip()
|
394 |
+
caption = f"<|#image#|>{self.tokenizer.pad_token * self.vis_embed_size}<|#endofimage#|>{text}"
|
395 |
+
encodings = self.tokenizer(
|
396 |
+
caption,
|
397 |
+
padding="longest",
|
398 |
+
truncation=True,
|
399 |
+
return_tensors="pt",
|
400 |
+
max_length=2000,
|
401 |
+
)
|
402 |
+
input_ids = encodings["input_ids"]
|
403 |
+
attention_mask = encodings["attention_mask"]
|
404 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
405 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
406 |
+
image_nums = [1] * len(input_ids)
|
407 |
+
added_bbox_list = []
|
408 |
+
if radio in ["Cap"]:
|
409 |
+
output_text, out_image = captioner(self.model,self.tokenizer,image_ori,batch_images,input_ids,attention_mask,image_start_index_list,image_nums,added_bbox_list)
|
410 |
+
else:
|
411 |
+
with torch.inference_mode():
|
412 |
+
text_outputs = self.model.generate(
|
413 |
+
batch_images,
|
414 |
+
input_ids,
|
415 |
+
attention_mask=attention_mask,
|
416 |
+
max_new_tokens=20,
|
417 |
+
# min_new_tokens=8,
|
418 |
+
num_beams=1,
|
419 |
+
# length_penalty=0,
|
420 |
+
image_start_index_list=image_start_index_list,
|
421 |
+
image_nums=image_nums,
|
422 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
423 |
+
)
|
424 |
+
# and torch.cuda.amp.autocast(dtype=torch.float16)
|
425 |
+
with torch.no_grad():
|
426 |
+
outputs = self.model(
|
427 |
+
vision_x=batch_images,
|
428 |
+
lang_x=input_ids,
|
429 |
+
attention_mask=attention_mask,
|
430 |
+
image_nums=image_nums,
|
431 |
+
image_start_index_list=image_start_index_list,
|
432 |
+
added_bbox_list=None,
|
433 |
+
add_box=False,
|
434 |
+
)
|
435 |
+
boxes = outputs["boxes"]
|
436 |
+
scores = outputs["scores"]
|
437 |
+
if len(scores) > 0:
|
438 |
+
box = boxes[scores.argmax()] / 224
|
439 |
+
print(f"{box}")
|
440 |
+
out_image = None
|
441 |
+
|
442 |
+
if len(boxes)>0:
|
443 |
+
width, height = image_ori.size
|
444 |
+
open_cv_image = np.array(image_ori)
|
445 |
+
# Convert RGB to BGR
|
446 |
+
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
447 |
+
box = box * [width, height, width, height]
|
448 |
+
# for box in boxes:
|
449 |
+
open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
|
450 |
+
out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
|
451 |
+
|
452 |
+
|
453 |
+
# output_token = outputs[0, input_ids.shape[1]:]
|
454 |
+
# output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
|
455 |
+
# conv[-1]["value"] = output_text
|
456 |
+
# # conv.messages[-1][1] = output_text
|
457 |
+
# print(
|
458 |
+
# f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
|
459 |
+
output_text = self.tokenizer.decode(text_outputs[0])
|
460 |
+
print(output_text)
|
461 |
+
output_text = re.findall(r'Assistant:(.+)', output_text)[-1]
|
462 |
+
print(output_text)
|
463 |
+
|
464 |
+
return output_text, out_image
|
465 |
+
|
466 |
+
def upload_img(self, image, conv, img_list):
|
467 |
+
img_list.append(image)
|
468 |
+
# if isinstance(image, str): # is a image path
|
469 |
+
# raw_image = Image.open(image).convert('RGB')
|
470 |
+
# image = image.resize((224, 224))
|
471 |
+
# image = self.vis_processor(raw_image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
472 |
+
# elif isinstance(image, Image.Image):
|
473 |
+
# raw_image = image
|
474 |
+
# image = image.resize((224, 224))
|
475 |
+
# image = self.vis_processor(raw_image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
476 |
+
# elif isinstance(image, torch.Tensor):
|
477 |
+
# if len(image.shape) == 3:
|
478 |
+
# image = image.unsqueeze(0)
|
479 |
+
# # image = image.to(self.device)
|
480 |
+
#
|
481 |
+
# # image_emb, _ = self.model.encode_img(image)
|
482 |
+
# img_list.append(image_emb)
|
483 |
+
# conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
|
484 |
+
msg = "Received."
|
485 |
+
# self.conv.append_message(self.conv.roles[1], msg)
|
486 |
+
return msg
|
487 |
+
|
488 |
+
# def get_context_emb(self, conv, img_list):
|
489 |
+
# prompt = conv.get_prompt()
|
490 |
+
# prompt_segs = prompt.split('<ImageHere>')
|
491 |
+
# assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
492 |
+
# seg_tokens = [
|
493 |
+
# self.model.llama_tokenizer(
|
494 |
+
# seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
|
495 |
+
# # only add bos to the first seg
|
496 |
+
# for i, seg in enumerate(prompt_segs)
|
497 |
+
# ]
|
498 |
+
# seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
|
499 |
+
# mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
500 |
+
# mixed_embs = torch.cat(mixed_embs, dim=1)
|
501 |
+
# return mixed_embs
|
502 |
+
|
503 |
+
def evaluate_exp(
|
504 |
+
model,
|
505 |
+
tokenizer,
|
506 |
+
image_processor,
|
507 |
+
vis_embed_size=None,
|
508 |
+
rank=0,
|
509 |
+
world_size=1,
|
510 |
+
id=0,
|
511 |
+
add_visual=True,
|
512 |
+
):
|
513 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
514 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
515 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
516 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
517 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
518 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
519 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
520 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
521 |
+
size = image_processor.size["shortest_edge"]
|
522 |
+
model.eval()
|
523 |
+
# "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
|
524 |
+
image_path = input("Please enter the image path: ")
|
525 |
+
image = Image.open(image_path).convert("RGB")
|
526 |
+
image = image.resize((size, size))
|
527 |
+
print(f"image size: {image.size}")
|
528 |
+
batch_images = preprocess_image(image, image_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
529 |
+
conversation = []
|
530 |
+
human_sentence = None
|
531 |
+
while True:
|
532 |
+
human_sentence = input("### Human: ")
|
533 |
+
if human_sentence == "#end#":
|
534 |
+
break
|
535 |
+
conversation.append({
|
536 |
+
"from": "human",
|
537 |
+
"value": human_sentence,
|
538 |
+
})
|
539 |
+
conversation.append({
|
540 |
+
"from": "gpt",
|
541 |
+
"value": "",
|
542 |
+
})
|
543 |
+
text = preprocess_conv(conversation).strip()
|
544 |
+
caption = f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}"
|
545 |
+
encodings = tokenizer(
|
546 |
+
caption,
|
547 |
+
padding="longest",
|
548 |
+
truncation=True,
|
549 |
+
return_tensors="pt",
|
550 |
+
max_length=2000,
|
551 |
+
)
|
552 |
+
input_ids = encodings["input_ids"].to("cuda")
|
553 |
+
attention_mask = encodings["attention_mask"].to("cuda")
|
554 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
555 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
556 |
+
image_nums = [1] * len(input_ids)
|
557 |
+
with torch.no_grad() and torch.cuda.amp.autocast(dtype=torch.float16):
|
558 |
+
outputs = model.generate(
|
559 |
+
batch_images,
|
560 |
+
input_ids,
|
561 |
+
attention_mask=attention_mask,
|
562 |
+
max_new_tokens=100,
|
563 |
+
# min_new_tokens=8,
|
564 |
+
num_beams=1,
|
565 |
+
image_start_index_list=image_start_index_list,
|
566 |
+
image_nums=image_nums,
|
567 |
+
)
|
568 |
+
print(f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
|
569 |
+
|
570 |
+
|
571 |
+
|
multimodal/build/lib/open_flamingo/eval/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
multimodal/build/lib/open_flamingo/eval/classification.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Sequence, Tuple
|
2 |
+
import re
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def postprocess_classification_generation(predictions) -> str:
|
8 |
+
return re.split("Prompt|Completion", predictions, 1)[0]
|
9 |
+
|
10 |
+
|
11 |
+
def compute_classification_accuracy(predictions: Sequence[Dict[str, str]]) -> float:
|
12 |
+
"""Compute the accuracy of a sequence of predictions."""
|
13 |
+
|
14 |
+
def _preprocess_fn(s):
|
15 |
+
"""Function to preprocess both targets and predictions."""
|
16 |
+
return s.lower()
|
17 |
+
|
18 |
+
is_correct = [
|
19 |
+
_preprocess_fn(x["prediction"]) == _preprocess_fn(x["class_label"])
|
20 |
+
for x in predictions
|
21 |
+
]
|
22 |
+
|
23 |
+
return np.mean(is_correct).item()
|
24 |
+
|
25 |
+
|
26 |
+
def compute_shifted_logits_and_labels(
|
27 |
+
logits: torch.Tensor, encodings, tokenizer, eoc_token_id
|
28 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
29 |
+
"""Helper function to compute shifted logits and labels.
|
30 |
+
|
31 |
+
This allows for straightforward computation of the loss on shift_logits
|
32 |
+
and shift_labels such that the nth element of logits computes the n-1th
|
33 |
+
element of the original labels (in the outputs, the nth element of logits
|
34 |
+
corresponds to the nth element of the labels).
|
35 |
+
|
36 |
+
Elements in shift_labels that correspond to inputs are masked with values
|
37 |
+
of -100 (by default in hf, loss is only computed on token IDs >= 0).
|
38 |
+
|
39 |
+
Returns: tuple containing two elements:
|
40 |
+
shift_logits: a float Tensor of shape [batch_size, seq_len - 1].
|
41 |
+
shift_labels: an integer Tensor of shape [batch_size, seq_len - 1]
|
42 |
+
"""
|
43 |
+
|
44 |
+
labels = encodings["input_ids"].clone()
|
45 |
+
|
46 |
+
# convert padding and EOC tokens to -100 so they are ignored in loss
|
47 |
+
labels[labels == tokenizer.pad_token_id] = -100
|
48 |
+
labels[labels == eoc_token_id] = -100
|
49 |
+
|
50 |
+
# Convert all tokens in prefix until separator to -100 so they are
|
51 |
+
# ignored in loss
|
52 |
+
for idx in range(len(labels)):
|
53 |
+
# Find the location of the last token of prefix *from right*,
|
54 |
+
# since the first non-padding token of the sequence will also be
|
55 |
+
# eos_token (because bos_token and eos_token are the same for
|
56 |
+
# the tokenizer).
|
57 |
+
end_of_prefix = -labels[idx].tolist()[::-1].index(tokenizer.eos_token_id) - 1
|
58 |
+
labels[idx, : end_of_prefix + 1] = -100
|
59 |
+
|
60 |
+
# Shift so that tokens < n predict n. The shifted tensors both have
|
61 |
+
# shape [batch_size, seq_len - 1].
|
62 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
63 |
+
shift_labels = labels[..., 1:].contiguous()
|
64 |
+
|
65 |
+
return shift_logits, shift_labels
|
66 |
+
|
67 |
+
|
68 |
+
def compute_per_sample_probs(
|
69 |
+
encodings, tokenizer, logits: torch.Tensor, eoc_token_id
|
70 |
+
) -> torch.Tensor:
|
71 |
+
"""Helper function to compute per-sample probability of the input sequence.
|
72 |
+
|
73 |
+
Assumes <eos token> is used to separate inputs from targets in the
|
74 |
+
prompt text
|
75 |
+
"""
|
76 |
+
shift_logits, shift_labels = compute_shifted_logits_and_labels(
|
77 |
+
logits, encodings, tokenizer, eoc_token_id
|
78 |
+
)
|
79 |
+
|
80 |
+
# Tuple of tensors for unmasked label tokens. The first element of the
|
81 |
+
# tuple contains the batch indices; the second element contains the
|
82 |
+
# sequence indices.
|
83 |
+
unmasked_indices = torch.nonzero(shift_labels != -100, as_tuple=True)
|
84 |
+
# Tensor where the i^th element is the token_id corresponding to the i^th
|
85 |
+
# element of unmasked_indices
|
86 |
+
unmasked_token_ids = shift_labels[unmasked_indices]
|
87 |
+
|
88 |
+
# 3d tensor of [batch_idx, sequence_position, token_id] for unmasked tokens.
|
89 |
+
target_idxs = torch.column_stack([*unmasked_indices, unmasked_token_ids])
|
90 |
+
target_idxs = target_idxs.to(shift_logits.device)
|
91 |
+
|
92 |
+
# Sanity check that every element in batch has at least one unmasked
|
93 |
+
# target token
|
94 |
+
assert torch.all(
|
95 |
+
torch.bincount(target_idxs[:, 0]) != 0
|
96 |
+
), "At least one element in batch has no unmasked target tokens."
|
97 |
+
|
98 |
+
# Renormalize over tokens to make sure they are proper probabilities via
|
99 |
+
# softmax over the token dimension.
|
100 |
+
shift_probs = torch.nn.functional.softmax(shift_logits, 2)
|
101 |
+
|
102 |
+
# Compute the probability of the target sequence (as the product of the
|
103 |
+
# probability of the individual tokens in the sequence).
|
104 |
+
target_probs = torch.ones(len(shift_labels), device=shift_logits.device)
|
105 |
+
for i, j, k in target_idxs:
|
106 |
+
target_probs[i] *= shift_probs[i, j, k]
|
107 |
+
|
108 |
+
return target_probs
|
109 |
+
|
110 |
+
|
111 |
+
def compute_per_sample_loss(encodings, tokenizer, logits, eoc_token_id) -> torch.Tensor:
|
112 |
+
"""Helper function to compute per-sample classification loss.
|
113 |
+
|
114 |
+
Assumes <eos token> is used to separate inputs from targets in the
|
115 |
+
prompt text
|
116 |
+
"""
|
117 |
+
shift_logits, shift_labels = compute_shifted_logits_and_labels(
|
118 |
+
logits, encodings, tokenizer, eoc_token_id
|
119 |
+
)
|
120 |
+
|
121 |
+
device = shift_logits.device
|
122 |
+
|
123 |
+
# Loss is computed token-wise, on Tensors of shape
|
124 |
+
# [batch_size * (seq_len - 1), vocab_size]
|
125 |
+
# and returns a loss tensor of shape
|
126 |
+
# [batch_size * (seq_len - 1)]. Most of the tokens will be masked
|
127 |
+
# in this computation.
|
128 |
+
loss = torch.nn.functional.cross_entropy(
|
129 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
130 |
+
shift_labels.view(-1).to(device),
|
131 |
+
reduction="none",
|
132 |
+
)
|
133 |
+
|
134 |
+
# Reshape to [batch_size, seq_len - 1]
|
135 |
+
loss = loss.view(shift_logits.size(0), shift_logits.size(1)).cpu()
|
136 |
+
|
137 |
+
# loss_mask is 1 for tokens we want included in the loss, and 0 for tokens
|
138 |
+
# that should be ignored in the loss.
|
139 |
+
loss_mask = (shift_labels != -100).int().cpu()
|
140 |
+
|
141 |
+
loss *= loss_mask
|
142 |
+
|
143 |
+
# Compute per-element loss : sum loss over all (unmasked) tokens and
|
144 |
+
# divide by number of variable tokens to obtain tensor of
|
145 |
+
# shape [batch_size,]
|
146 |
+
loss = loss.sum(dim=1) / (shift_labels != -100).sum(dim=1).float()
|
147 |
+
return loss
|
multimodal/build/lib/open_flamingo/eval/coco_metric.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pycocoevalcap.eval import COCOEvalCap
|
2 |
+
from pycocotools.coco import COCO
|
3 |
+
import json
|
4 |
+
|
5 |
+
|
6 |
+
def compute_cider(
|
7 |
+
result_path,
|
8 |
+
annotations_path,
|
9 |
+
):
|
10 |
+
# create coco object and coco_result object
|
11 |
+
coco = COCO(annotations_path)
|
12 |
+
coco_result = coco.loadRes(result_path)
|
13 |
+
|
14 |
+
# create coco_eval object by taking coco and coco_result
|
15 |
+
coco_eval = COCOEvalCap(coco, coco_result)
|
16 |
+
coco_eval.params["image_id"] = coco_result.getImgIds()
|
17 |
+
coco_eval.evaluate()
|
18 |
+
|
19 |
+
return coco_eval.eval
|
20 |
+
|
21 |
+
|
22 |
+
def postprocess_captioning_generation(predictions):
|
23 |
+
return predictions
|
multimodal/build/lib/open_flamingo/eval/dataset_zoo/__init__.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .aro_datasets import VG_Relation, VG_Attribution, COCO_Order, Flickr30k_Order
|
2 |
+
from .retrieval import COCO_Retrieval, Flickr30k_Retrieval
|
3 |
+
|
4 |
+
|
5 |
+
def get_dataset(dataset_name, image_preprocess=None, text_perturb_fn=None, image_perturb_fn=None, download=False, *args, **kwargs):
|
6 |
+
"""
|
7 |
+
Helper function that returns a dataset object with an evaluation function.
|
8 |
+
dataset_name: Name of the dataset.
|
9 |
+
image_preprocess: Preprocessing function for images.
|
10 |
+
text_perturb_fn: A function that takes in a string and returns a string. This is for perturbation experiments.
|
11 |
+
image_perturb_fn: A function that takes in a PIL image and returns a PIL image. This is for perturbation experiments.
|
12 |
+
download: Whether to allow downloading images if they are not found.
|
13 |
+
"""
|
14 |
+
if dataset_name == "VG_Relation":
|
15 |
+
from .aro_datasets import get_visual_genome_relation
|
16 |
+
return get_visual_genome_relation(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
|
17 |
+
elif dataset_name == "VG_Attribution":
|
18 |
+
from .aro_datasets import get_visual_genome_attribution
|
19 |
+
return get_visual_genome_attribution(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
|
20 |
+
elif dataset_name == "COCO_Order":
|
21 |
+
from .aro_datasets import get_coco_order
|
22 |
+
return get_coco_order(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
|
23 |
+
elif dataset_name == "Flickr30k_Order":
|
24 |
+
from .aro_datasets import get_flickr30k_order
|
25 |
+
return get_flickr30k_order(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
|
26 |
+
elif dataset_name == "COCO_Retrieval":
|
27 |
+
from .retrieval import get_coco_retrieval
|
28 |
+
return get_coco_retrieval(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
|
29 |
+
elif dataset_name == "Flickr30k_Retrieval":
|
30 |
+
from .retrieval import get_flickr30k_retrieval
|
31 |
+
return get_flickr30k_retrieval(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
|
32 |
+
else:
|
33 |
+
raise ValueError(f"Unknown dataset {dataset_name}")
|
multimodal/build/lib/open_flamingo/eval/dataset_zoo/aro_datasets.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import subprocess
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
from tqdm import tqdm
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
from easydict import EasyDict as edict
|
11 |
+
from torchvision.datasets.utils import download_url
|
12 |
+
|
13 |
+
from .perturbations import TextShuffler
|
14 |
+
from .constants import ARO_ROOT, COCO_ROOT, FLICKR_ROOT
|
15 |
+
from .retrieval import pre_caption
|
16 |
+
|
17 |
+
|
18 |
+
class VG_Relation(Dataset):
|
19 |
+
def __init__(self, image_preprocess, text_perturb_fn=None, image_perturb_fn=None, root_dir=ARO_ROOT, download=False):
|
20 |
+
'''
|
21 |
+
image_preprocess: a function that takes in a PIL image and returns a tensor.
|
22 |
+
text_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
|
23 |
+
image_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
|
24 |
+
root_dir: Directory for the VG-R dataset.
|
25 |
+
download: Whether to download the dataset if it does not exist.
|
26 |
+
'''
|
27 |
+
self.root_dir = root_dir
|
28 |
+
annotation_file = os.path.join(root_dir, "visual_genome_relation.json")
|
29 |
+
image_dir = os.path.join(root_dir, "images")
|
30 |
+
if not os.path.exists(image_dir):
|
31 |
+
print("Image Directory for VG_Relation could not be found!")
|
32 |
+
if download:
|
33 |
+
self.download()
|
34 |
+
else:
|
35 |
+
raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
|
36 |
+
|
37 |
+
if not os.path.exists(annotation_file):
|
38 |
+
subprocess.call(["gdown", "--id", "1kX2iCHEv0CADL8dSO1nMdW-V0NqIAiP3", "--output", annotation_file])
|
39 |
+
|
40 |
+
with open(annotation_file, "r") as f:
|
41 |
+
self.dataset = json.load(f)
|
42 |
+
|
43 |
+
self.all_relations = list()
|
44 |
+
for item in self.dataset:
|
45 |
+
item["image_path"] = os.path.join(image_dir, item["image_path"])
|
46 |
+
self.all_relations.append(item["relation_name"])
|
47 |
+
|
48 |
+
self.image_preprocess = image_preprocess
|
49 |
+
|
50 |
+
def __len__(self):
|
51 |
+
return len(self.dataset)
|
52 |
+
|
53 |
+
def __getitem__(self, index):
|
54 |
+
test_case = self.dataset[index]
|
55 |
+
image = Image.open(test_case["image_path"]).convert('RGB')
|
56 |
+
# Get the bounding box that contains the relation. This is to remove the irrelevant details in the scene.
|
57 |
+
image = image.crop((test_case["bbox_x"], test_case["bbox_y"], test_case["bbox_x"] + test_case["bbox_w"], test_case["bbox_y"] + test_case["bbox_h"]))
|
58 |
+
|
59 |
+
if self.image_preprocess is not None:
|
60 |
+
image = self.image_preprocess(image)
|
61 |
+
|
62 |
+
# Each test case has a correct and incorrect caption.
|
63 |
+
true_caption = test_case["true_caption"]
|
64 |
+
false_caption = test_case["false_caption"]
|
65 |
+
item = edict({"image_options": [image], "caption_options": [false_caption, true_caption]})
|
66 |
+
return item
|
67 |
+
|
68 |
+
def download(self):
|
69 |
+
os.makedirs(self.root_dir, exist_ok=True)
|
70 |
+
image_zip_file = os.path.join(self.root_dir, "vgr_vga_images.zip")
|
71 |
+
subprocess.call(["gdown", "--no-cookies", "1qaPlrwhGNMrR3a11iopZUT_GPP_LrgP9", "--output", image_zip_file])
|
72 |
+
subprocess.call(["unzip", "vgr_vga_images.zip"], cwd=self.root_dir)
|
73 |
+
|
74 |
+
|
75 |
+
def evaluate_scores(self, scores):
|
76 |
+
"""
|
77 |
+
Scores: N x 1 x 2, i.e. first caption is the perturbed one, second is the positive one
|
78 |
+
"""
|
79 |
+
if isinstance(scores, tuple):
|
80 |
+
scores_i2t = scores[1]
|
81 |
+
scores_t2i = scores[0]
|
82 |
+
else:
|
83 |
+
scores_t2i = scores
|
84 |
+
scores_i2t = scores
|
85 |
+
|
86 |
+
metrics = {"Accuracy": None}
|
87 |
+
preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
|
88 |
+
correct_mask = (preds == 1)
|
89 |
+
metrics["Accuracy"] = np.mean(correct_mask)
|
90 |
+
|
91 |
+
all_relations = np.array(self.all_relations)
|
92 |
+
|
93 |
+
result_records = []
|
94 |
+
# Log the accuracy of all relations
|
95 |
+
for relation in np.unique(all_relations):
|
96 |
+
relation_mask = (all_relations == relation)
|
97 |
+
if relation_mask.sum() == 0:
|
98 |
+
continue
|
99 |
+
result_records.append({
|
100 |
+
"Relation": relation,
|
101 |
+
"Accuracy": correct_mask[relation_mask].mean(),
|
102 |
+
"Count": relation_mask.sum(),
|
103 |
+
"Dataset": "Visual Genome Relation"
|
104 |
+
})
|
105 |
+
return result_records
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
class VG_Attribution(Dataset):
|
110 |
+
def __init__(self, image_preprocess, text_perturb_fn=None, image_perturb_fn=None, root_dir=ARO_ROOT, download=False):
|
111 |
+
'''
|
112 |
+
image_preprocess: a function that takes in a PIL image and returns a tensor.
|
113 |
+
text_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
|
114 |
+
image_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
|
115 |
+
root_dir: Directory for the VG-A dataset.
|
116 |
+
'''
|
117 |
+
self.root_dir = root_dir
|
118 |
+
annotation_file = os.path.join(root_dir, "visual_genome_attribution.json")
|
119 |
+
image_dir = os.path.join(root_dir, "images")
|
120 |
+
if not os.path.exists(image_dir):
|
121 |
+
print("Image Directory for VG_Attribution could not be found!")
|
122 |
+
if download:
|
123 |
+
self.download()
|
124 |
+
else:
|
125 |
+
raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
|
126 |
+
|
127 |
+
|
128 |
+
if not os.path.exists(annotation_file):
|
129 |
+
subprocess.call(["gdown", "--id", "13tWvOrNOLHxl3Rm9cR3geAdHx2qR3-Tw", "--output", annotation_file])
|
130 |
+
|
131 |
+
with open(annotation_file, "r") as f:
|
132 |
+
self.dataset = json.load(f)
|
133 |
+
|
134 |
+
for item in self.dataset:
|
135 |
+
item["image_path"] = os.path.join(image_dir, item["image_path"])
|
136 |
+
|
137 |
+
# Set of attributes in each test case
|
138 |
+
self.all_attributes = [f"{item['attributes'][0]}_{item['attributes'][1]}" for item in self.dataset]
|
139 |
+
self.image_preprocess = image_preprocess
|
140 |
+
|
141 |
+
def __len__(self):
|
142 |
+
return len(self.dataset)
|
143 |
+
|
144 |
+
def __getitem__(self, index):
|
145 |
+
test_case = self.dataset[index]
|
146 |
+
image = Image.open(test_case["image_path"]).convert('RGB')
|
147 |
+
# Get the bounding box that contains the relation. This is to remove the irrelevant details in the scene.
|
148 |
+
image = image.crop((test_case["bbox_x"], test_case["bbox_y"], test_case["bbox_x"] + test_case["bbox_w"], test_case["bbox_y"] + test_case["bbox_h"]))
|
149 |
+
|
150 |
+
if self.image_preprocess is not None:
|
151 |
+
image = self.image_preprocess(image)
|
152 |
+
|
153 |
+
# Each test case has a correct and incorrect caption.
|
154 |
+
true_caption = test_case["true_caption"]
|
155 |
+
false_caption = test_case["false_caption"]
|
156 |
+
item = edict({"image_options": [image], "caption_options": [false_caption, true_caption]})
|
157 |
+
return item
|
158 |
+
|
159 |
+
def download(self):
|
160 |
+
os.makedirs(self.root_dir, exist_ok=True)
|
161 |
+
image_zip_file = os.path.join(self.root_dir, "vgr_vga_images.zip")
|
162 |
+
subprocess.call(["gdown", "--no-cookies", "1qaPlrwhGNMrR3a11iopZUT_GPP_LrgP9", "--output", image_zip_file])
|
163 |
+
subprocess.call(["unzip", "vgr_vga_images.zip"], cwd=self.root_dir)
|
164 |
+
|
165 |
+
|
166 |
+
def evaluate_scores(self, scores):
|
167 |
+
"""
|
168 |
+
Scores: N x 1 x 2, i.e. first caption is the perturbed one, second is the positive one
|
169 |
+
"""
|
170 |
+
if isinstance(scores, tuple):
|
171 |
+
scores_i2t = scores[1]
|
172 |
+
scores_t2i = scores[0]
|
173 |
+
else:
|
174 |
+
scores_t2i = scores
|
175 |
+
scores_i2t = scores
|
176 |
+
|
177 |
+
preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
|
178 |
+
correct_mask = (preds == 1)
|
179 |
+
result_records = []
|
180 |
+
all_attributes = np.array(self.all_attributes)
|
181 |
+
for attr in np.unique(all_attributes):
|
182 |
+
attr_mask = (all_attributes == attr)
|
183 |
+
if attr_mask.sum() < 25:
|
184 |
+
continue
|
185 |
+
result_records.append({
|
186 |
+
"Attributes": attr,
|
187 |
+
"Accuracy": correct_mask[attr_mask].mean(),
|
188 |
+
"Count": attr_mask.sum(),
|
189 |
+
"Dataset": "Visual Genome Attribution"
|
190 |
+
})
|
191 |
+
return result_records
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
class COCO_Order(Dataset):
|
197 |
+
def __init__(self, image_preprocess=None, root_dir=COCO_ROOT, max_words=30, split="test",
|
198 |
+
image_perturb_fn=None, download=False):
|
199 |
+
"""
|
200 |
+
COCO Order Dataset.
|
201 |
+
image_preprocess: image preprocessing function
|
202 |
+
root_dir: The directory of the coco dataset. This directory should contain test2014 files.
|
203 |
+
max_words: Cropping the caption to max_words.
|
204 |
+
split: 'val' or 'test'
|
205 |
+
image_perturb_fn: not used; for compatibility.
|
206 |
+
download: Whether to download the dataset if it does not exist.
|
207 |
+
"""
|
208 |
+
shuffler = TextShuffler()
|
209 |
+
perturb_functions = [shuffler.shuffle_nouns_and_adj, shuffler.shuffle_allbut_nouns_and_adj,
|
210 |
+
shuffler.shuffle_within_trigrams, shuffler.shuffle_trigrams]
|
211 |
+
|
212 |
+
self.root_dir = root_dir
|
213 |
+
if not os.path.exists(root_dir):
|
214 |
+
print("Directory for COCO could not be found!")
|
215 |
+
if download:
|
216 |
+
print("Downloading COCO now.")
|
217 |
+
self.download()
|
218 |
+
else:
|
219 |
+
raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
|
220 |
+
|
221 |
+
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
|
222 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
|
223 |
+
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
|
224 |
+
download_url(urls[split],root_dir)
|
225 |
+
|
226 |
+
self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
|
227 |
+
self.image_preprocess = image_preprocess
|
228 |
+
self.image_root = root_dir
|
229 |
+
|
230 |
+
self.test_cases = []
|
231 |
+
|
232 |
+
for img_id, ann in tqdm(enumerate(self.annotation)):
|
233 |
+
for i, caption in enumerate(ann['caption']):
|
234 |
+
test_case = {}
|
235 |
+
test_case["image"] = ann["image"]
|
236 |
+
test_case["caption_options"] = [pre_caption(caption,max_words)]
|
237 |
+
|
238 |
+
for perturb_fn in perturb_functions:
|
239 |
+
test_case["caption_options"].append(pre_caption(perturb_fn(caption), max_words))
|
240 |
+
self.test_cases.append(test_case)
|
241 |
+
|
242 |
+
def __len__(self):
|
243 |
+
return len(self.test_cases)
|
244 |
+
|
245 |
+
def __getitem__(self, index):
|
246 |
+
test_case = self.test_cases[index]
|
247 |
+
image_path = os.path.join(self.image_root, test_case["image"])
|
248 |
+
|
249 |
+
image = Image.open(image_path).convert('RGB')
|
250 |
+
if self.image_preprocess is not None:
|
251 |
+
image = self.image_preprocess(image)
|
252 |
+
|
253 |
+
item = edict({"image_options": [image], "caption_options": test_case["caption_options"]})
|
254 |
+
return item
|
255 |
+
|
256 |
+
def download(self):
|
257 |
+
import subprocess
|
258 |
+
os.makedirs(self.root_dir, exist_ok=True)
|
259 |
+
#subprocess.call(["wget", "http://images.cocodataset.org/zips/train2014.zip"], cwd=self.root_dir)
|
260 |
+
#subprocess.call(["unzip", "train2014.zip"], cwd=self.root_dir)
|
261 |
+
|
262 |
+
subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=self.root_dir)
|
263 |
+
subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)
|
264 |
+
|
265 |
+
subprocess.call(["wget", "http://images.cocodataset.org/zips/test2014.zip"], cwd=self.root_dir)
|
266 |
+
subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)
|
267 |
+
|
268 |
+
|
269 |
+
def evaluate_scores(self, scores):
|
270 |
+
if isinstance(scores, tuple):
|
271 |
+
scores_i2t = scores[0]
|
272 |
+
scores_t2i = scores[1].T # Make it N_ims x N_text
|
273 |
+
|
274 |
+
else:
|
275 |
+
scores_t2i = scores
|
276 |
+
scores_i2t = scores
|
277 |
+
|
278 |
+
preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
|
279 |
+
correct_mask = (preds == 0)
|
280 |
+
records = [{"Precision@1": np.mean(correct_mask)}]
|
281 |
+
return records
|
282 |
+
|
283 |
+
|
284 |
+
class Flickr30k_Order(Dataset):
|
285 |
+
def __init__(self, image_preprocess, split, root_dir=FLICKR_ROOT, max_words=30,
|
286 |
+
*args, **kwargs):
|
287 |
+
"""
|
288 |
+
image_preprocess: image preprocessing function
|
289 |
+
split: 'val' or 'test'
|
290 |
+
root_dir: The directory of the flickr30k images. This should contain the `flickr30k-images` directory that \
|
291 |
+
contains all the images.
|
292 |
+
"""
|
293 |
+
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
|
294 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
|
295 |
+
filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
|
296 |
+
if not os.path.exists(root_dir):
|
297 |
+
print("Directory for Flickr30k could not be found!")
|
298 |
+
flickr_url = "https://forms.illinois.edu/sec/229675"
|
299 |
+
raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")
|
300 |
+
|
301 |
+
download_url(urls[split],root_dir)
|
302 |
+
|
303 |
+
self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
|
304 |
+
self.image_preprocess = image_preprocess
|
305 |
+
self.root_dir = root_dir
|
306 |
+
|
307 |
+
self.test_cases = []
|
308 |
+
|
309 |
+
shuffler = TextShuffler()
|
310 |
+
perturb_functions = [shuffler.shuffle_nouns_and_adj, shuffler.shuffle_allbut_nouns_and_adj,
|
311 |
+
shuffler.shuffle_within_trigrams, shuffler.shuffle_trigrams]
|
312 |
+
for img_id, ann in tqdm(enumerate(self.annotation)):
|
313 |
+
for i, caption in enumerate(ann['caption']):
|
314 |
+
test_case = {}
|
315 |
+
test_case["image"] = ann["image"]
|
316 |
+
test_case["caption_options"] = [pre_caption(caption,max_words)]
|
317 |
+
|
318 |
+
for perturb_fn in perturb_functions:
|
319 |
+
test_case["caption_options"].append(pre_caption(perturb_fn(caption), max_words))
|
320 |
+
self.test_cases.append(test_case)
|
321 |
+
|
322 |
+
def __len__(self):
|
323 |
+
return len(self.test_cases)
|
324 |
+
|
325 |
+
def __getitem__(self, index):
|
326 |
+
test_case = self.test_cases[index]
|
327 |
+
image_path = os.path.join(self.root_dir, test_case["image"])
|
328 |
+
image = Image.open(image_path).convert('RGB')
|
329 |
+
|
330 |
+
if self.image_preprocess is not None:
|
331 |
+
image = self.image_preprocess(image)
|
332 |
+
|
333 |
+
item = edict({"image_options": [image], "caption_options": test_case["caption_options"]})
|
334 |
+
return item
|
335 |
+
|
336 |
+
def evaluate_scores(self, scores):
|
337 |
+
if isinstance(scores, tuple):
|
338 |
+
scores_i2t = scores[0]
|
339 |
+
scores_t2i = scores[1].T # Make it N_ims x N_text
|
340 |
+
else:
|
341 |
+
scores_t2i = scores
|
342 |
+
scores_i2t = scores
|
343 |
+
|
344 |
+
preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
|
345 |
+
correct_mask = (preds == 0)
|
346 |
+
result_records = [{"Precision@1": np.mean(correct_mask)}]
|
347 |
+
return result_records
|
348 |
+
|
349 |
+
|
350 |
+
def get_visual_genome_relation(image_preprocess, text_perturb_fn=None, image_perturb_fn=None, download=False):
|
351 |
+
return VG_Relation(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download)
|
352 |
+
|
353 |
+
|
354 |
+
def get_visual_genome_attribution(image_preprocess, text_perturb_fn=None, image_perturb_fn=None, download=False):
|
355 |
+
return VG_Attribution(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn,
|
356 |
+
image_perturb_fn=image_perturb_fn, download=download)
|
357 |
+
|
358 |
+
def get_coco_order(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=COCO_ROOT, split="test"):
|
359 |
+
return COCO_Order(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words,
|
360 |
+
download=download)
|
361 |
+
|
362 |
+
def get_flickr30k_order(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=FLICKR_ROOT, split="test"):
|
363 |
+
return Flickr30k_Order(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words,
|
364 |
+
download=download)
|
365 |
+
|
multimodal/build/lib/open_flamingo/eval/dataset_zoo/constants.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
ARO_ROOT = "~/.cache/prerelease_bow"
|
2 |
+
COCO_ROOT = "~/.cache/coco/2014"
|
3 |
+
FLICKR_ROOT = "~/.cache/flickr30k/images"
|
multimodal/build/lib/open_flamingo/eval/dataset_zoo/perturbations.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
from functools import partial
|
5 |
+
import torch.nn.functional as nnf
|
6 |
+
from torchvision import transforms as T
|
7 |
+
|
8 |
+
# A lot of the approaches here are inspired from the wonderful paper from O'Connor and Andreas 2021.
|
9 |
+
# https://github.com/lingo-mit/context-ablations
|
10 |
+
|
11 |
+
def get_text_perturb_fn(text_perturb_fn):
|
12 |
+
if text_perturb_fn == "shuffle_nouns_and_adj":
|
13 |
+
return shuffle_nouns_and_adj
|
14 |
+
elif text_perturb_fn == "shuffle_allbut_nouns_and_adj":
|
15 |
+
return shuffle_allbut_nouns_and_adj
|
16 |
+
elif text_perturb_fn == "shuffle_within_trigrams":
|
17 |
+
return shuffle_within_trigrams
|
18 |
+
elif text_perturb_fn == "shuffle_all_words":
|
19 |
+
return shuffle_all_words
|
20 |
+
elif text_perturb_fn == "shuffle_trigrams":
|
21 |
+
return shuffle_trigrams
|
22 |
+
elif text_perturb_fn is None:
|
23 |
+
return None
|
24 |
+
else:
|
25 |
+
print("Unknown text perturbation function: {}, returning None".format(text_perturb_fn))
|
26 |
+
return None
|
27 |
+
|
28 |
+
|
29 |
+
def get_image_perturb_fn(image_perturb_fn):
|
30 |
+
if image_perturb_fn == "shuffle_rows_4":
|
31 |
+
return partial(shuffle_rows, n_rows=4)
|
32 |
+
elif image_perturb_fn == "shuffle_patches_9":
|
33 |
+
return partial(shuffle_patches, n_ratio=3)
|
34 |
+
elif image_perturb_fn == "shuffle_cols_4":
|
35 |
+
return partial(shuffle_columns, n_cols=4)
|
36 |
+
elif image_perturb_fn is None:
|
37 |
+
return None
|
38 |
+
else:
|
39 |
+
print("Unknown image perturbation function: {}, returning None".format(image_perturb_fn))
|
40 |
+
return None
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
class TextShuffler:
|
45 |
+
|
46 |
+
def __init__(self):
|
47 |
+
import spacy
|
48 |
+
self.nlp = spacy.load("en_core_web_sm")
|
49 |
+
|
50 |
+
def shuffle_nouns_and_adj(self, ex):
|
51 |
+
|
52 |
+
doc = self.nlp(ex)
|
53 |
+
tokens = [token.text for token in doc]
|
54 |
+
text = np.array(tokens)
|
55 |
+
noun_idx = [i for i, token in enumerate(doc) if token.tag_ in ['NN', 'NNS', 'NNP', 'NNPS']]
|
56 |
+
## Finding adjectives
|
57 |
+
adjective_idx = [i for i, token in enumerate(doc) if token.tag_ in ['JJ', 'JJR', 'JJS']]
|
58 |
+
## Shuffle the nouns of the text
|
59 |
+
text[noun_idx] = np.random.permutation(text[noun_idx])
|
60 |
+
## Shuffle the adjectives of the text
|
61 |
+
text[adjective_idx] = np.random.permutation(text[adjective_idx])
|
62 |
+
|
63 |
+
return " ".join(text)
|
64 |
+
|
65 |
+
def shuffle_all_words(self, ex):
|
66 |
+
return " ".join(np.random.permutation(ex.split(" ")))
|
67 |
+
|
68 |
+
|
69 |
+
def shuffle_allbut_nouns_and_adj(self, ex):
|
70 |
+
doc = self.nlp(ex)
|
71 |
+
tokens = [token.text for token in doc]
|
72 |
+
text = np.array(tokens)
|
73 |
+
noun_adj_idx = [i for i, token in enumerate(doc) if token.tag_ in ['NN', 'NNS', 'NNP', 'NNPS', 'JJ', 'JJR', 'JJS']]
|
74 |
+
## Finding adjectives
|
75 |
+
|
76 |
+
else_idx = np.ones(text.shape[0])
|
77 |
+
else_idx[noun_adj_idx] = 0
|
78 |
+
|
79 |
+
else_idx = else_idx.astype(bool)
|
80 |
+
## Shuffle everything that are nouns or adjectives
|
81 |
+
text[else_idx] = np.random.permutation(text[else_idx])
|
82 |
+
return " ".join(text)
|
83 |
+
|
84 |
+
|
85 |
+
def get_trigrams(self, sentence):
|
86 |
+
# Taken from https://github.com/lingo-mit/context-ablations/blob/478fb18a9f9680321f0d37dc999ea444e9287cc0/code/transformers/src/transformers/data/data_augmentation.py
|
87 |
+
trigrams = []
|
88 |
+
trigram = []
|
89 |
+
for i in range(len(sentence)):
|
90 |
+
trigram.append(sentence[i])
|
91 |
+
if i % 3 == 2:
|
92 |
+
trigrams.append(trigram[:])
|
93 |
+
trigram = []
|
94 |
+
if trigram:
|
95 |
+
trigrams.append(trigram)
|
96 |
+
return trigrams
|
97 |
+
|
98 |
+
def trigram_shuffle(self, sentence):
|
99 |
+
trigrams = self.get_trigrams(sentence)
|
100 |
+
for trigram in trigrams:
|
101 |
+
random.shuffle(trigram)
|
102 |
+
return " ".join([" ".join(trigram) for trigram in trigrams])
|
103 |
+
|
104 |
+
|
105 |
+
def shuffle_within_trigrams(self, ex):
|
106 |
+
import nltk
|
107 |
+
tokens = nltk.word_tokenize(ex)
|
108 |
+
shuffled_ex = self.trigram_shuffle(tokens)
|
109 |
+
return shuffled_ex
|
110 |
+
|
111 |
+
|
112 |
+
def shuffle_trigrams(self, ex):
|
113 |
+
import nltk
|
114 |
+
tokens = nltk.word_tokenize(ex)
|
115 |
+
trigrams = self.get_trigrams(tokens)
|
116 |
+
random.shuffle(trigrams)
|
117 |
+
shuffled_ex = " ".join([" ".join(trigram) for trigram in trigrams])
|
118 |
+
return shuffled_ex
|
119 |
+
|
120 |
+
|
121 |
+
def _handle_image_4shuffle(x):
|
122 |
+
return_image = False
|
123 |
+
if not isinstance(x, torch.Tensor):
|
124 |
+
# print(f"x is not a tensor: {type(x)}. Trying to handle but fix this or I'll annoy you with this log")
|
125 |
+
t = torch.tensor(np.array(x)).unsqueeze(dim=0).float()
|
126 |
+
t = t.permute(0, 3, 1, 2)
|
127 |
+
return_image = True
|
128 |
+
return t, return_image
|
129 |
+
if len(x.shape) != 4:
|
130 |
+
#print("You did not send a tensor of shape NxCxWxH. Unsqueezing not but fix this or I'll annoy you with this log")
|
131 |
+
return x.unsqueeze(dim=0), return_image
|
132 |
+
else:
|
133 |
+
# Good boi
|
134 |
+
return x, return_image
|
135 |
+
|
136 |
+
|
137 |
+
def shuffle_rows(x, n_rows=7):
|
138 |
+
"""
|
139 |
+
Shuffle the rows of the image tensor where each row has a size of 14 pixels.
|
140 |
+
Tensor is of shape N x C x W x H
|
141 |
+
"""
|
142 |
+
x, return_image = _handle_image_4shuffle(x)
|
143 |
+
patch_size = x.shape[-2]//n_rows
|
144 |
+
u = nnf.unfold(x, kernel_size=(patch_size, x.shape[-1]), stride=patch_size, padding=0)
|
145 |
+
# permute the patches of each image in the batch
|
146 |
+
pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
|
147 |
+
# fold the permuted patches back together
|
148 |
+
f = nnf.fold(pu, x.shape[-2:], kernel_size=(patch_size, x.shape[-1]), stride=patch_size, padding=0)
|
149 |
+
|
150 |
+
image = f.squeeze() # C W H
|
151 |
+
if return_image:
|
152 |
+
return T.ToPILImage()(image.type(torch.uint8))
|
153 |
+
else:
|
154 |
+
return image
|
155 |
+
|
156 |
+
|
157 |
+
def shuffle_columns(x, n_cols=7):
|
158 |
+
"""
|
159 |
+
Shuffle the columns of the image tensor where we'll have n_cols columns.
|
160 |
+
Tensor is of shape N x C x W x H
|
161 |
+
"""
|
162 |
+
x, return_image = _handle_image_4shuffle(x)
|
163 |
+
patch_size = x.shape[-1]//n_cols
|
164 |
+
u = nnf.unfold(x, kernel_size=(x.shape[-2], patch_size), stride=patch_size, padding=0)
|
165 |
+
# permute the patches of each image in the batch
|
166 |
+
pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
|
167 |
+
# fold the permuted patches back together
|
168 |
+
f = nnf.fold(pu, x.shape[-2:], kernel_size=(x.shape[-2], patch_size), stride=patch_size, padding=0)
|
169 |
+
image = f.squeeze() # C W H
|
170 |
+
if return_image:
|
171 |
+
return T.ToPILImage()(image.type(torch.uint8))
|
172 |
+
else:
|
173 |
+
return image
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
def shuffle_patches(x, n_ratio=4):
|
178 |
+
"""
|
179 |
+
Shuffle the rows of the image tensor where each row has a size of 14 pixels.
|
180 |
+
Tensor is of shape N x C x W x H
|
181 |
+
"""
|
182 |
+
x, return_image = _handle_image_4shuffle(x)
|
183 |
+
patch_size_x = x.shape[-2]//n_ratio
|
184 |
+
patch_size_y = x.shape[-1]//n_ratio
|
185 |
+
u = nnf.unfold(x, kernel_size=(patch_size_x, patch_size_y), stride=(patch_size_x, patch_size_y), padding=0)
|
186 |
+
# permute the patches of each image in the batch
|
187 |
+
pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
|
188 |
+
# fold the permuted patches back together
|
189 |
+
f = nnf.fold(pu, x.shape[-2:], kernel_size=(patch_size_x, patch_size_y), stride=(patch_size_x, patch_size_y), padding=0)
|
190 |
+
image = f.squeeze() # C W H
|
191 |
+
if return_image:
|
192 |
+
return T.ToPILImage()(image.type(torch.uint8))
|
193 |
+
else:
|
194 |
+
return image
|
multimodal/build/lib/open_flamingo/eval/dataset_zoo/retrieval.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
from tqdm import tqdm
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
from torchvision.datasets.utils import download_url
|
10 |
+
|
11 |
+
from .constants import COCO_ROOT, FLICKR_ROOT
|
12 |
+
from .utils import AverageMeter
|
13 |
+
|
14 |
+
|
15 |
+
def pre_caption(caption,max_words=50):
|
16 |
+
caption = re.sub(
|
17 |
+
r"([.!\"()*#:;~])",
|
18 |
+
' ',
|
19 |
+
caption.lower(),
|
20 |
+
)
|
21 |
+
caption = re.sub(
|
22 |
+
r"\s{2,}",
|
23 |
+
' ',
|
24 |
+
caption,
|
25 |
+
)
|
26 |
+
caption = caption.rstrip('\n')
|
27 |
+
caption = caption.strip(' ')
|
28 |
+
|
29 |
+
#truncate caption
|
30 |
+
caption_words = caption.split(' ')
|
31 |
+
if len(caption_words)>max_words:
|
32 |
+
caption = ' '.join(caption_words[:max_words])
|
33 |
+
|
34 |
+
return caption
|
35 |
+
|
36 |
+
|
37 |
+
class COCO_Retrieval(Dataset):
|
38 |
+
def __init__(self, image_preprocess=None, root_dir=COCO_ROOT, max_words=30, split="test",
|
39 |
+
image_perturb_fn=None, download=False):
|
40 |
+
"""
|
41 |
+
COCO Retrieval Dataset.
|
42 |
+
image_preprocess: image preprocessing function
|
43 |
+
root_dir: The directory of the coco dataset. This directory should contain test2014 files.
|
44 |
+
max_words: Cropping the caption to max_words.
|
45 |
+
split: 'val' or 'test'
|
46 |
+
image_perturb_fn: image perturbation function for patch permutation experiments.
|
47 |
+
download: Whether to download the dataset if it does not exist.
|
48 |
+
"""
|
49 |
+
self.root_dir = root_dir
|
50 |
+
if not os.path.exists(root_dir):
|
51 |
+
print("Directory for COCO could not be found!")
|
52 |
+
if download:
|
53 |
+
print("Downloading COCO now.")
|
54 |
+
self.download()
|
55 |
+
else:
|
56 |
+
raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
|
57 |
+
|
58 |
+
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
|
59 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
|
60 |
+
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
|
61 |
+
download_url(urls[split],root_dir)
|
62 |
+
|
63 |
+
|
64 |
+
self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
|
65 |
+
self.image_preprocess = image_preprocess
|
66 |
+
self.image_perturb_fn = image_perturb_fn
|
67 |
+
self.image_root = root_dir
|
68 |
+
|
69 |
+
self.text = []
|
70 |
+
self.image = []
|
71 |
+
self.txt2img = {}
|
72 |
+
self.img2txt = {}
|
73 |
+
|
74 |
+
txt_id = 0
|
75 |
+
for img_id, ann in enumerate(self.annotation):
|
76 |
+
self.image.append(ann['image'])
|
77 |
+
self.img2txt[img_id] = []
|
78 |
+
for i, caption in enumerate(ann['caption']):
|
79 |
+
self.text.append(pre_caption(caption,max_words))
|
80 |
+
self.img2txt[img_id].append(txt_id)
|
81 |
+
self.txt2img[txt_id] = img_id
|
82 |
+
txt_id += 1
|
83 |
+
|
84 |
+
def __len__(self):
|
85 |
+
return len(self.annotation)
|
86 |
+
|
87 |
+
def __getitem__(self, index):
|
88 |
+
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
|
89 |
+
image = Image.open(image_path).convert('RGB')
|
90 |
+
|
91 |
+
if self.image_preprocess is not None:
|
92 |
+
image = self.image_preprocess(image)
|
93 |
+
|
94 |
+
if self.image_perturb_fn is not None:
|
95 |
+
image = self.image_perturb_fn(image)
|
96 |
+
|
97 |
+
return {"image": image, "idx": index}
|
98 |
+
|
99 |
+
def download(self):
|
100 |
+
import subprocess
|
101 |
+
os.makedirs(self.root_dir, exist_ok=True)
|
102 |
+
#subprocess.call(["wget", "http://images.cocodataset.org/zips/train2014.zip"], cwd=self.root_dir)
|
103 |
+
#subprocess.call(["unzip", "train2014.zip"], cwd=self.root_dir)
|
104 |
+
|
105 |
+
subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=self.root_dir)
|
106 |
+
subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)
|
107 |
+
|
108 |
+
subprocess.call(["wget", "http://images.cocodataset.org/zips/test2014.zip"], cwd=self.root_dir)
|
109 |
+
subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)
|
110 |
+
|
111 |
+
|
112 |
+
def evaluate_scores(self, scores):
|
113 |
+
if isinstance(scores, tuple):
|
114 |
+
scores_i2t = scores[0]
|
115 |
+
scores_t2i = scores[1].T # Make it N_ims x N_text
|
116 |
+
|
117 |
+
else:
|
118 |
+
scores_t2i = scores
|
119 |
+
scores_i2t = scores
|
120 |
+
|
121 |
+
print(f"COCO results across {scores_i2t.shape} samples. ")
|
122 |
+
prec_at_1 = AverageMeter()
|
123 |
+
prec_at_5 = AverageMeter()
|
124 |
+
|
125 |
+
# Text retrieval
|
126 |
+
tqdm_iterator = tqdm(range(len(self.img2txt)))
|
127 |
+
for i in tqdm_iterator:
|
128 |
+
top5_captions = np.argsort(scores_i2t[i])[-5:]
|
129 |
+
true_captions = self.img2txt[i]
|
130 |
+
|
131 |
+
prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:]))>0)
|
132 |
+
prec_at_5.update(len(set(true_captions) & set(top5_captions))>0)
|
133 |
+
|
134 |
+
tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")
|
135 |
+
|
136 |
+
# Image Retrieval
|
137 |
+
image_prec_at_1 = AverageMeter()
|
138 |
+
image_prec_at_5 = AverageMeter()
|
139 |
+
|
140 |
+
tqdm_iterator = tqdm(range(len(self.txt2img)))
|
141 |
+
for i in tqdm_iterator:
|
142 |
+
top5_images = np.argsort(scores_t2i[:, i])[-5:]
|
143 |
+
true_image = self.txt2img[i]
|
144 |
+
|
145 |
+
image_prec_at_1.update(true_image in top5_images[-1:])
|
146 |
+
image_prec_at_5.update(true_image in top5_images)
|
147 |
+
|
148 |
+
tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")
|
149 |
+
|
150 |
+
records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
|
151 |
+
return records
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
class Flickr30k_Retrieval(Dataset):
|
156 |
+
def __init__(self, image_preprocess, split, root_dir=FLICKR_ROOT, max_words=30,
|
157 |
+
image_perturb_fn=None, *args, **kwargs):
|
158 |
+
'''
|
159 |
+
Flickr30k dataset for retrieval.
|
160 |
+
image_preprocess: image preprocessing function
|
161 |
+
root_dir: The directory of the coco dataset. This directory should contain test2014 files.
|
162 |
+
max_words: Cropping the caption to max_words.
|
163 |
+
split: 'val' or 'test'
|
164 |
+
image_perturb_fn: image perturbation function for patch permutation experiments.
|
165 |
+
download: Whether to download the dataset if it does not exist.
|
166 |
+
'''
|
167 |
+
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
|
168 |
+
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
|
169 |
+
filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
|
170 |
+
|
171 |
+
if not os.path.exists(root_dir):
|
172 |
+
print("Directory for Flickr30k could not be found!")
|
173 |
+
flickr_url = "https://forms.illinois.edu/sec/229675"
|
174 |
+
raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")
|
175 |
+
|
176 |
+
download_url(urls[split],root_dir)
|
177 |
+
|
178 |
+
self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
|
179 |
+
self.image_preprocess = image_preprocess
|
180 |
+
self.image_perturb_fn = image_perturb_fn
|
181 |
+
self.root_dir = root_dir
|
182 |
+
|
183 |
+
self.text = []
|
184 |
+
self.image = []
|
185 |
+
self.txt2img = {}
|
186 |
+
self.img2txt = {}
|
187 |
+
|
188 |
+
txt_id = 0
|
189 |
+
for img_id, ann in enumerate(self.annotation):
|
190 |
+
self.image.append(ann['image'])
|
191 |
+
self.img2txt[img_id] = []
|
192 |
+
for i, caption in enumerate(ann['caption']):
|
193 |
+
self.text.append(pre_caption(caption,max_words))
|
194 |
+
self.img2txt[img_id].append(txt_id)
|
195 |
+
self.txt2img[txt_id] = img_id
|
196 |
+
txt_id += 1
|
197 |
+
|
198 |
+
def __len__(self):
|
199 |
+
return len(self.annotation)
|
200 |
+
|
201 |
+
def __getitem__(self, index):
|
202 |
+
image_path = os.path.join(self.root_dir, self.annotation[index]['image'])
|
203 |
+
image = Image.open(image_path).convert('RGB')
|
204 |
+
if self.image_preprocess is not None:
|
205 |
+
image = self.image_preprocess(image)
|
206 |
+
if self.image_perturb_fn is not None:
|
207 |
+
image = self.image_perturb_fn(image)
|
208 |
+
|
209 |
+
return {"image": image, "idx": index}
|
210 |
+
|
211 |
+
def evaluate_scores(self, scores):
|
212 |
+
if isinstance(scores, tuple):
|
213 |
+
scores_i2t = scores[0]
|
214 |
+
scores_t2i = scores[1].T # Make it N_ims x N_text
|
215 |
+
|
216 |
+
else:
|
217 |
+
scores_t2i = scores
|
218 |
+
scores_i2t = scores
|
219 |
+
|
220 |
+
print(f"Flickr30k Retrieval results across {scores_i2t.shape} samples. ")
|
221 |
+
prec_at_1 = AverageMeter()
|
222 |
+
prec_at_5 = AverageMeter()
|
223 |
+
|
224 |
+
# Text retrieval
|
225 |
+
tqdm_iterator = tqdm(range(len(self.img2txt)))
|
226 |
+
for i in tqdm_iterator:
|
227 |
+
top5_captions = np.argsort(scores_i2t[i])[-5:]
|
228 |
+
true_captions = self.img2txt[i]
|
229 |
+
|
230 |
+
prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:]))>0)
|
231 |
+
prec_at_5.update(len(set(true_captions) & set(top5_captions))>0)
|
232 |
+
|
233 |
+
tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")
|
234 |
+
|
235 |
+
# Image Retrieval
|
236 |
+
image_prec_at_1 = AverageMeter()
|
237 |
+
image_prec_at_5 = AverageMeter()
|
238 |
+
|
239 |
+
tqdm_iterator = tqdm(range(len(self.txt2img)))
|
240 |
+
for i in tqdm_iterator:
|
241 |
+
top5_images = np.argsort(scores_t2i[:, i])[-5:]
|
242 |
+
true_image = self.txt2img[i]
|
243 |
+
|
244 |
+
image_prec_at_1.update(true_image in top5_images[-1:])
|
245 |
+
image_prec_at_5.update(true_image in top5_images)
|
246 |
+
|
247 |
+
tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")
|
248 |
+
|
249 |
+
records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
|
250 |
+
return records
|
251 |
+
|
252 |
+
def download(self):
|
253 |
+
raise NotImplementedError("Flickr30k dataset is not available for download.")
|
254 |
+
|
255 |
+
|
256 |
+
|
257 |
+
def get_coco_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=COCO_ROOT, split="test"):
|
258 |
+
dataset = COCO_Retrieval(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words,
|
259 |
+
download=download)
|
260 |
+
return dataset
|
261 |
+
|
262 |
+
|
263 |
+
def get_flickr30k_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=FLICKR_ROOT, split="test"):
|
264 |
+
dataset = Flickr30k_Retrieval(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words,
|
265 |
+
download=download)
|
266 |
+
return dataset
|
multimodal/build/lib/open_flamingo/eval/dataset_zoo/utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class AverageMeter(object):
|
2 |
+
def __init__(self):
|
3 |
+
self.reset()
|
4 |
+
|
5 |
+
def reset(self):
|
6 |
+
self.val = 0
|
7 |
+
self.avg = 0
|
8 |
+
self.sum = 0
|
9 |
+
self.count = 0
|
10 |
+
|
11 |
+
def update(self, val, n=1):
|
12 |
+
self.val = val
|
13 |
+
self.sum += val * n
|
14 |
+
self.count += n
|
15 |
+
self.avg = self.sum / self.count
|
multimodal/build/lib/open_flamingo/eval/eval_datasets.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torchvision.datasets import ImageFolder
|
7 |
+
|
8 |
+
from open_flamingo.eval.imagenet_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
|
9 |
+
|
10 |
+
|
11 |
+
class COCOFlickrDataset(Dataset):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
image_dir_path,
|
15 |
+
annotations_path,
|
16 |
+
is_flickr=False,
|
17 |
+
):
|
18 |
+
self.image_dir_path = image_dir_path
|
19 |
+
self.annotations = json.load(open(annotations_path))["annotations"]
|
20 |
+
self.is_flickr = is_flickr
|
21 |
+
|
22 |
+
def __len__(self):
|
23 |
+
return len(self.annotations)
|
24 |
+
|
25 |
+
def get_img_path(self, idx):
|
26 |
+
if self.is_flickr:
|
27 |
+
return f"{self.image_dir_path}/{self.annotations[idx]['image_id']}.jpg"
|
28 |
+
else:
|
29 |
+
return f"{self.image_dir_path}/{self.annotations[idx]['image_id']:012d}.jpg"
|
30 |
+
|
31 |
+
def __getitem__(self, idx):
|
32 |
+
image = Image.open(self.get_img_path(idx))
|
33 |
+
caption = self.annotations[idx]["caption"]
|
34 |
+
return {
|
35 |
+
"image": image,
|
36 |
+
"caption": caption,
|
37 |
+
"image_id": self.annotations[idx]["image_id"],
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
class VQADataset(Dataset):
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
image_dir_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/train2014/",
|
45 |
+
question_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_OpenEnded_mscoco_train2014_questions.json",
|
46 |
+
annotations_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_mscoco_train2014_annotations.json",
|
47 |
+
vqa_dataset="vqa",
|
48 |
+
):
|
49 |
+
self.questions = json.load(open(question_path, "r"))["questions"]
|
50 |
+
self.answers = json.load(open(annotations_path, "r"))["annotations"]
|
51 |
+
self.image_dir_path = image_dir_path
|
52 |
+
self.vqa_dataset = vqa_dataset
|
53 |
+
|
54 |
+
def __len__(self):
|
55 |
+
return len(self.questions)
|
56 |
+
|
57 |
+
def get_img_path(self, question):
|
58 |
+
if self.vqa_dataset == "vqa":
|
59 |
+
return os.path.join(
|
60 |
+
self.image_dir_path, f"COCO_val2014_{question['image_id']:012d}.jpg"
|
61 |
+
)
|
62 |
+
elif self.vqa_dataset == "ok_vqa":
|
63 |
+
return os.path.join(
|
64 |
+
self.image_dir_path, f"COCO_val2014_{question['image_id']:012d}.jpg"
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
raise Exception(f"Unknown VQA dataset {self.vqa_dataset}")
|
68 |
+
|
69 |
+
def __getitem__(self, idx):
|
70 |
+
question = self.questions[idx]
|
71 |
+
answers = self.answers[idx]
|
72 |
+
img_path = self.get_img_path(question)
|
73 |
+
image = Image.open(img_path)
|
74 |
+
return {
|
75 |
+
"image": image,
|
76 |
+
"question": question["question"],
|
77 |
+
"answers": [a["answer"] for a in answers["answers"]],
|
78 |
+
"question_id": question["question_id"],
|
79 |
+
}
|
80 |
+
|
81 |
+
|
82 |
+
class ImageNetDataset(ImageFolder):
|
83 |
+
"""Class to represent the ImageNet1k dataset."""
|
84 |
+
|
85 |
+
def __init__(self, root, **kwargs):
|
86 |
+
super().__init__(root=root, **kwargs)
|
87 |
+
|
88 |
+
def __getitem__(self, idx):
|
89 |
+
sample, target = super().__getitem__(idx)
|
90 |
+
target_label = IMAGENET_1K_CLASS_ID_TO_LABEL[target]
|
91 |
+
return {
|
92 |
+
"image": sample,
|
93 |
+
"class_id": target, # numeric ID of the ImageNet class
|
94 |
+
"class_name": target_label, # human-readable name of ImageNet class
|
95 |
+
}
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == "__main__":
|
99 |
+
gqa_dataset = GQADataset()
|
100 |
+
for sample in gqa_dataset:
|
101 |
+
print(sample)
|
multimodal/build/lib/open_flamingo/eval/evaluate.py
ADDED
@@ -0,0 +1,1435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from math import ceil
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import uuid
|
7 |
+
from collections import defaultdict
|
8 |
+
from typing import Callable
|
9 |
+
import time
|
10 |
+
import cv2
|
11 |
+
import webdataset as wds
|
12 |
+
from sklearn.metrics import recall_score, average_precision_score
|
13 |
+
|
14 |
+
import more_itertools
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from coco_metric import compute_cider, postprocess_captioning_generation
|
18 |
+
from eval_datasets import VQADataset
|
19 |
+
from tqdm import tqdm
|
20 |
+
from collections import Counter
|
21 |
+
|
22 |
+
from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
|
23 |
+
from open_flamingo.eval.classification import (
|
24 |
+
compute_per_sample_probs,
|
25 |
+
compute_per_sample_loss,
|
26 |
+
)
|
27 |
+
from open_flamingo.eval.imagenet_utils import (
|
28 |
+
openai_imagenet_classnames,
|
29 |
+
IMAGENET_1K_CLASS_ID_TO_LABEL,
|
30 |
+
)
|
31 |
+
|
32 |
+
from open_flamingo.src.factory import create_model_and_transforms
|
33 |
+
from PIL import Image
|
34 |
+
from io import BytesIO
|
35 |
+
import base64
|
36 |
+
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
|
37 |
+
import string
|
38 |
+
from open_flamingo.eval.task.reg import evaluate_reg
|
39 |
+
from open_flamingo.eval.task.gqa import GQADataset
|
40 |
+
from open_flamingo.eval.task.vl_checklist import evaluate_vlc
|
41 |
+
from open_flamingo.eval.task.crepe import evaluate_crepe
|
42 |
+
from open_flamingo.eval.task.caption import evaluate_coco_flickr
|
43 |
+
from open_flamingo.eval.task.utils import is_correct, get_iou
|
44 |
+
from open_flamingo.eval.task.cola import evaluate_cola
|
45 |
+
from open_flamingo.eval.task.gqa import evaluate_gqa
|
46 |
+
|
47 |
+
def expand2square(pil_img, background_color):
|
48 |
+
width, height = pil_img.size
|
49 |
+
if width == height:
|
50 |
+
return pil_img
|
51 |
+
elif width > height:
|
52 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
53 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
54 |
+
return result
|
55 |
+
else:
|
56 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
57 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
58 |
+
return result
|
59 |
+
|
60 |
+
parser = argparse.ArgumentParser()
|
61 |
+
parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
|
62 |
+
parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
|
63 |
+
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
|
64 |
+
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
|
65 |
+
parser.add_argument("--checkpoint_path", type=str, required=True)
|
66 |
+
parser.add_argument(
|
67 |
+
"--results_file", type=str, default=None, help="JSON file to save results"
|
68 |
+
)
|
69 |
+
|
70 |
+
# Trial arguments
|
71 |
+
parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
|
72 |
+
parser.add_argument(
|
73 |
+
"--num_trials",
|
74 |
+
type=int,
|
75 |
+
default=1,
|
76 |
+
help="Number of trials to run for each shot using different demonstrations",
|
77 |
+
)
|
78 |
+
parser.add_argument(
|
79 |
+
"--trial_seeds",
|
80 |
+
nargs="+",
|
81 |
+
default=[0],
|
82 |
+
help="Seeds to use for each trial for picking demonstrations and eval sets",
|
83 |
+
)
|
84 |
+
parser.add_argument(
|
85 |
+
"--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
|
86 |
+
)
|
87 |
+
|
88 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
89 |
+
|
90 |
+
# Per-dataset evaluation flags
|
91 |
+
parser.add_argument(
|
92 |
+
"--eval_coco",
|
93 |
+
action="store_true",
|
94 |
+
default=False,
|
95 |
+
help="Whether to evaluate on COCO.",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"--eval_vqav2",
|
99 |
+
action="store_true",
|
100 |
+
default=False,
|
101 |
+
help="Whether to evaluate on VQAV2.",
|
102 |
+
)
|
103 |
+
parser.add_argument(
|
104 |
+
"--eval_ok_vqa",
|
105 |
+
action="store_true",
|
106 |
+
default=False,
|
107 |
+
help="Whether to evaluate on OK-VQA.",
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--eval_imagenet",
|
111 |
+
action="store_true",
|
112 |
+
default=False,
|
113 |
+
help="Whether to evaluate on ImageNet.",
|
114 |
+
)
|
115 |
+
|
116 |
+
parser.add_argument(
|
117 |
+
"--eval_flickr30",
|
118 |
+
action="store_true",
|
119 |
+
default=False,
|
120 |
+
help="Whether to evaluate on Flickr30.",
|
121 |
+
)
|
122 |
+
|
123 |
+
parser.add_argument(
|
124 |
+
"--eval_refcoco",
|
125 |
+
action="store_true",
|
126 |
+
default=False,
|
127 |
+
help="Whether to evaluate on RefCOCO.",
|
128 |
+
)
|
129 |
+
|
130 |
+
# Dataset arguments
|
131 |
+
|
132 |
+
## Flickr30 Dataset
|
133 |
+
parser.add_argument(
|
134 |
+
"--flickr_image_dir_path",
|
135 |
+
type=str,
|
136 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
137 |
+
default=None,
|
138 |
+
)
|
139 |
+
parser.add_argument(
|
140 |
+
"--flickr_annotations_json_path",
|
141 |
+
type=str,
|
142 |
+
help="Path to the dataset_flickr30k_coco_style.json file.",
|
143 |
+
default=None,
|
144 |
+
)
|
145 |
+
|
146 |
+
## COCO Dataset
|
147 |
+
parser.add_argument(
|
148 |
+
"--coco_image_dir_path",
|
149 |
+
type=str,
|
150 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
151 |
+
default=None,
|
152 |
+
)
|
153 |
+
parser.add_argument(
|
154 |
+
"--coco_annotations_json_path",
|
155 |
+
type=str,
|
156 |
+
default=None,
|
157 |
+
)
|
158 |
+
|
159 |
+
## VQAV2 Dataset
|
160 |
+
parser.add_argument(
|
161 |
+
"--vqav2_image_dir_path",
|
162 |
+
type=str,
|
163 |
+
default=None,
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--vqav2_questions_json_path",
|
167 |
+
type=str,
|
168 |
+
default=None,
|
169 |
+
)
|
170 |
+
parser.add_argument(
|
171 |
+
"--vqav2_annotations_json_path",
|
172 |
+
type=str,
|
173 |
+
default=None,
|
174 |
+
)
|
175 |
+
|
176 |
+
## OK-VQA Dataset
|
177 |
+
parser.add_argument(
|
178 |
+
"--ok_vqa_image_dir_path",
|
179 |
+
type=str,
|
180 |
+
help="Path to the vqav2/train2014 directory.",
|
181 |
+
default=None,
|
182 |
+
)
|
183 |
+
parser.add_argument(
|
184 |
+
"--ok_vqa_questions_json_path",
|
185 |
+
type=str,
|
186 |
+
help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
|
187 |
+
default=None,
|
188 |
+
)
|
189 |
+
parser.add_argument(
|
190 |
+
"--ok_vqa_annotations_json_path",
|
191 |
+
type=str,
|
192 |
+
help="Path to the v2_mscoco_train2014_annotations.json file.",
|
193 |
+
default=None,
|
194 |
+
)
|
195 |
+
|
196 |
+
## Imagenet dataset
|
197 |
+
parser.add_argument("--imagenet_root", type=str, default="/tmp")
|
198 |
+
|
199 |
+
## RefCOCO dataset
|
200 |
+
parser.add_argument("--refcoco_tsvfile", type=str, default=None)
|
201 |
+
|
202 |
+
parser.add_argument(
|
203 |
+
"--location_token_num",
|
204 |
+
default=1000,
|
205 |
+
type=int,
|
206 |
+
)
|
207 |
+
# distributed training
|
208 |
+
parser.add_argument(
|
209 |
+
"--dist-url",
|
210 |
+
default="env://",
|
211 |
+
type=str,
|
212 |
+
help="url used to set up distributed training",
|
213 |
+
)
|
214 |
+
parser.add_argument(
|
215 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
216 |
+
)
|
217 |
+
parser.add_argument(
|
218 |
+
"--horovod",
|
219 |
+
default=False,
|
220 |
+
action="store_true",
|
221 |
+
help="Use horovod for distributed training.",
|
222 |
+
)
|
223 |
+
parser.add_argument(
|
224 |
+
"--no-set-device-rank",
|
225 |
+
default=False,
|
226 |
+
action="store_true",
|
227 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
228 |
+
)
|
229 |
+
parser.add_argument(
|
230 |
+
"--dist",
|
231 |
+
default=False,
|
232 |
+
action="store_true",
|
233 |
+
)
|
234 |
+
parser.add_argument(
|
235 |
+
"--lora",
|
236 |
+
default=False,
|
237 |
+
action="store_true",
|
238 |
+
)
|
239 |
+
parser.add_argument(
|
240 |
+
"--lora_r",
|
241 |
+
default=16,
|
242 |
+
type=int,
|
243 |
+
required=False,
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--legacy",
|
247 |
+
default=False,
|
248 |
+
action="store_true",
|
249 |
+
)
|
250 |
+
parser.add_argument(
|
251 |
+
"--special",
|
252 |
+
default=False,
|
253 |
+
action="store_true",
|
254 |
+
)
|
255 |
+
parser.add_argument(
|
256 |
+
"--id",
|
257 |
+
default=0,
|
258 |
+
type=int,
|
259 |
+
required=False,
|
260 |
+
)
|
261 |
+
|
262 |
+
parser.add_argument(
|
263 |
+
"--eval_gqa",
|
264 |
+
default=False,
|
265 |
+
action="store_true",
|
266 |
+
)
|
267 |
+
parser.add_argument(
|
268 |
+
"--use_sam",
|
269 |
+
default=None,
|
270 |
+
type=str,
|
271 |
+
required=False,
|
272 |
+
)
|
273 |
+
parser.add_argument(
|
274 |
+
"--add_visual_token",
|
275 |
+
default=False,
|
276 |
+
action="store_true",
|
277 |
+
)
|
278 |
+
parser.add_argument(
|
279 |
+
"--use_format_v2",
|
280 |
+
default=False,
|
281 |
+
action="store_true",
|
282 |
+
)
|
283 |
+
parser.add_argument(
|
284 |
+
"--eval_aro",
|
285 |
+
default=False,
|
286 |
+
action="store_true",
|
287 |
+
)
|
288 |
+
parser.add_argument(
|
289 |
+
"--eval_pisc",
|
290 |
+
default=False,
|
291 |
+
action="store_true",
|
292 |
+
)
|
293 |
+
parser.add_argument(
|
294 |
+
"--eval_reg",
|
295 |
+
default=False,
|
296 |
+
action="store_true",
|
297 |
+
)
|
298 |
+
parser.add_argument(
|
299 |
+
"--eval_vlc",
|
300 |
+
default=False,
|
301 |
+
action="store_true",
|
302 |
+
)
|
303 |
+
parser.add_argument(
|
304 |
+
"--eval_crepe",
|
305 |
+
default=False,
|
306 |
+
action="store_true",
|
307 |
+
)
|
308 |
+
parser.add_argument(
|
309 |
+
"--eval_cola",
|
310 |
+
default=False,
|
311 |
+
action="store_true",
|
312 |
+
)
|
313 |
+
parser.add_argument(
|
314 |
+
"--level",
|
315 |
+
default=4,
|
316 |
+
type=int,
|
317 |
+
)
|
318 |
+
parser.add_argument(
|
319 |
+
"--type",
|
320 |
+
default="swap",
|
321 |
+
type=str,
|
322 |
+
)
|
323 |
+
parser.add_argument(
|
324 |
+
"--choose_left_right",
|
325 |
+
default=False,
|
326 |
+
action="store_true",
|
327 |
+
)
|
328 |
+
|
329 |
+
|
330 |
+
class OKVQAPostProcess():
|
331 |
+
def __init__(self):
|
332 |
+
self._lemmatizer = None
|
333 |
+
|
334 |
+
def _lemmatize(self, answers):
|
335 |
+
def apply(answer):
|
336 |
+
doc = self.lemmatizer(answer)
|
337 |
+
|
338 |
+
words = []
|
339 |
+
for token in doc:
|
340 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
341 |
+
words.append(token.lemma_)
|
342 |
+
else:
|
343 |
+
words.append(token.text)
|
344 |
+
answer = " ".join(words)
|
345 |
+
|
346 |
+
return answer
|
347 |
+
|
348 |
+
return [apply(answer) for answer in answers]
|
349 |
+
|
350 |
+
@property
|
351 |
+
def lemmatizer(self):
|
352 |
+
if self._lemmatizer is None:
|
353 |
+
try:
|
354 |
+
import spacy
|
355 |
+
|
356 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
357 |
+
except ImportError:
|
358 |
+
logging.error(
|
359 |
+
"""
|
360 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
361 |
+
python -m spacy download en_core_web_sm
|
362 |
+
OR
|
363 |
+
import spacy.cli
|
364 |
+
spacy.cli.download("en_core_web_sm")
|
365 |
+
"""
|
366 |
+
)
|
367 |
+
exit(1)
|
368 |
+
|
369 |
+
return self._lemmatizer
|
370 |
+
|
371 |
+
|
372 |
+
def main():
|
373 |
+
args = parser.parse_args()
|
374 |
+
if args.dist:
|
375 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
376 |
+
print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
|
377 |
+
device_id = init_distributed_device(args)
|
378 |
+
else:
|
379 |
+
args.rank = 0
|
380 |
+
args.world_size = 1
|
381 |
+
print(f"rank: {args.rank} world_size: {args.world_size}")
|
382 |
+
|
383 |
+
if "sam" in args.checkpoint_path:
|
384 |
+
args.use_sam = "vit_l"
|
385 |
+
|
386 |
+
args.add_visual_token = True
|
387 |
+
if "lora" in args.checkpoint_path:
|
388 |
+
args.lora = True
|
389 |
+
|
390 |
+
|
391 |
+
args.add_pe = False
|
392 |
+
args.add_box = True
|
393 |
+
args.relation = False
|
394 |
+
args.enhance_data = False
|
395 |
+
args.use_format_v2 = True
|
396 |
+
|
397 |
+
|
398 |
+
|
399 |
+
import hashlib
|
400 |
+
args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest()
|
401 |
+
|
402 |
+
# load model
|
403 |
+
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
|
404 |
+
args.vision_encoder_path,
|
405 |
+
args.vision_encoder_pretrained,
|
406 |
+
args.lm_path,
|
407 |
+
args.lm_tokenizer_path,
|
408 |
+
location_token_num=args.location_token_num,
|
409 |
+
lora=args.lora,
|
410 |
+
lora_r=16,
|
411 |
+
use_sam=args.use_sam,
|
412 |
+
add_visual_token=args.add_visual_token,
|
413 |
+
use_format_v2=args.use_format_v2,
|
414 |
+
add_box=args.add_box,
|
415 |
+
add_pe=args.add_pe,
|
416 |
+
add_relation=args.relation,
|
417 |
+
enhance_data=args.enhance_data,
|
418 |
+
)
|
419 |
+
flamingo.use_format_v2 = args.use_format_v2
|
420 |
+
if args.special:
|
421 |
+
flamingo.special = True
|
422 |
+
else:
|
423 |
+
flamingo.special = False
|
424 |
+
if args.legacy:
|
425 |
+
flamingo.legacy = True
|
426 |
+
print("use legacy evaluation")
|
427 |
+
flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
|
428 |
+
flamingo.expr_name = args.checkpoint_path.split("/")[-2]
|
429 |
+
if args.rank == 0:
|
430 |
+
print("legacy", True if hasattr(flamingo, "legacy") else False)
|
431 |
+
print("step:", flamingo.step_num)
|
432 |
+
print("expr:", flamingo.expr_name)
|
433 |
+
print("use format v2:", flamingo.use_format_v2)
|
434 |
+
print(args)
|
435 |
+
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
436 |
+
model_state_dict = {}
|
437 |
+
for key in checkpoint["model_state_dict"].keys():
|
438 |
+
model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
|
439 |
+
if "vision_encoder.logit_scale"in model_state_dict:
|
440 |
+
# previous checkpoint has some unnecessary weights
|
441 |
+
del model_state_dict["vision_encoder.logit_scale"]
|
442 |
+
del model_state_dict["vision_encoder.visual.proj"]
|
443 |
+
del model_state_dict["vision_encoder.visual.ln_post.weight"]
|
444 |
+
del model_state_dict["vision_encoder.visual.ln_post.bias"]
|
445 |
+
flamingo.load_state_dict(model_state_dict, strict=True)
|
446 |
+
results = defaultdict(list)
|
447 |
+
if args.eval_coco:
|
448 |
+
print("Evaluating on COCO...")
|
449 |
+
cider_score = evaluate_coco_flickr(
|
450 |
+
model=flamingo,
|
451 |
+
tokenizer=tokenizer,
|
452 |
+
image_processor=image_processor,
|
453 |
+
batch_size=args.batch_size,
|
454 |
+
vis_embed_size=vis_embed_size,
|
455 |
+
rank=args.rank,
|
456 |
+
world_size=args.world_size,
|
457 |
+
id=args.id,
|
458 |
+
)
|
459 |
+
results["coco"].append({"score": cider_score})
|
460 |
+
|
461 |
+
if args.eval_ok_vqa:
|
462 |
+
print("Evaluating on OK-VQA...")
|
463 |
+
for shot in args.shots:
|
464 |
+
scores = []
|
465 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
466 |
+
ok_vqa_score = evaluate_vqa(
|
467 |
+
model=flamingo,
|
468 |
+
tokenizer=tokenizer,
|
469 |
+
image_processor=image_processor,
|
470 |
+
batch_size=args.batch_size,
|
471 |
+
image_dir_path=args.ok_vqa_image_dir_path,
|
472 |
+
questions_json_path=args.ok_vqa_questions_json_path,
|
473 |
+
annotations_json_path=args.ok_vqa_annotations_json_path,
|
474 |
+
vqa_dataset="ok_vqa",
|
475 |
+
vis_embed_size=vis_embed_size,
|
476 |
+
rank=args.rank,
|
477 |
+
world_size=args.world_size,
|
478 |
+
id=args.id,
|
479 |
+
)
|
480 |
+
results["ok_vqa"].append(
|
481 |
+
{"shots": shot, "score": ok_vqa_score}
|
482 |
+
)
|
483 |
+
|
484 |
+
if args.eval_vqav2:
|
485 |
+
print("Evaluating on VQAv2...")
|
486 |
+
for shot in args.shots:
|
487 |
+
scores = []
|
488 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
489 |
+
vqa_score = evaluate_vqa(
|
490 |
+
model=flamingo,
|
491 |
+
tokenizer=tokenizer,
|
492 |
+
image_processor=image_processor,
|
493 |
+
batch_size=args.batch_size,
|
494 |
+
image_dir_path=args.vqav2_image_dir_path,
|
495 |
+
questions_json_path=args.vqav2_questions_json_path,
|
496 |
+
annotations_json_path=args.vqav2_annotations_json_path,
|
497 |
+
vqa_dataset="vqa",
|
498 |
+
vis_embed_size=vis_embed_size,
|
499 |
+
rank=args.rank,
|
500 |
+
world_size=args.world_size,
|
501 |
+
id=args.id,
|
502 |
+
)
|
503 |
+
results["vqav2"].append(
|
504 |
+
{"shots": shot, "score": vqa_score}
|
505 |
+
)
|
506 |
+
|
507 |
+
if args.eval_gqa:
|
508 |
+
print("Evaluating on GQA...")
|
509 |
+
gqa_score = evaluate_gqa(
|
510 |
+
model=flamingo,
|
511 |
+
tokenizer=tokenizer,
|
512 |
+
image_processor=image_processor,
|
513 |
+
batch_size=args.batch_size,
|
514 |
+
vis_embed_size=vis_embed_size,
|
515 |
+
rank=args.rank,
|
516 |
+
world_size=args.world_size,
|
517 |
+
id=args.id,
|
518 |
+
)
|
519 |
+
results["gqa"].append(
|
520 |
+
{"score": gqa_score}
|
521 |
+
)
|
522 |
+
|
523 |
+
if args.eval_refcoco:
|
524 |
+
print("Evaluating on RefCOCO...")
|
525 |
+
refcoco_score = evaluate_refcoco(
|
526 |
+
model=flamingo,
|
527 |
+
tokenizer=tokenizer,
|
528 |
+
image_processor=image_processor,
|
529 |
+
batch_size=args.batch_size,
|
530 |
+
device=args.device,
|
531 |
+
tsvfile=args.refcoco_tsvfile,
|
532 |
+
vis_embed_size=vis_embed_size,
|
533 |
+
rank=args.rank,
|
534 |
+
world_size=args.world_size,
|
535 |
+
id=args.id,
|
536 |
+
)
|
537 |
+
results["refcoco"].append(
|
538 |
+
{"score": refcoco_score}
|
539 |
+
)
|
540 |
+
if args.eval_aro:
|
541 |
+
print("Evaluating on ARO...")
|
542 |
+
aro_score = evaluate_aro(
|
543 |
+
model=flamingo,
|
544 |
+
tokenizer=tokenizer,
|
545 |
+
image_processor=image_processor,
|
546 |
+
vis_embed_size=vis_embed_size,
|
547 |
+
rank=args.rank,
|
548 |
+
world_size=args.world_size,
|
549 |
+
id=args.id,
|
550 |
+
choose_left_right=args.choose_left_right,
|
551 |
+
)
|
552 |
+
results["aro"].append(
|
553 |
+
{"score": aro_score}
|
554 |
+
)
|
555 |
+
if args.eval_pisc:
|
556 |
+
print("Evaluating on ARO...")
|
557 |
+
aro_score = evaluate_pisc(
|
558 |
+
model=flamingo,
|
559 |
+
tokenizer=tokenizer,
|
560 |
+
image_processor=image_processor,
|
561 |
+
batch_size=args.batch_size,
|
562 |
+
device=args.device,
|
563 |
+
tsvfile=args.refcoco_tsvfile,
|
564 |
+
vis_embed_size=vis_embed_size,
|
565 |
+
rank=args.rank,
|
566 |
+
world_size=args.world_size,
|
567 |
+
id=args.id,
|
568 |
+
)
|
569 |
+
results["pisc"].append(
|
570 |
+
{"score": aro_score}
|
571 |
+
)
|
572 |
+
if args.eval_reg:
|
573 |
+
print("Evaluating on Referring Expression Generation...")
|
574 |
+
cider = evaluate_reg(
|
575 |
+
model=flamingo,
|
576 |
+
tokenizer=tokenizer,
|
577 |
+
image_processor=image_processor,
|
578 |
+
vis_embed_size=vis_embed_size,
|
579 |
+
rank=args.rank,
|
580 |
+
world_size=args.world_size,
|
581 |
+
id=args.id,
|
582 |
+
)
|
583 |
+
results["reg"].append(
|
584 |
+
{"score": cider}
|
585 |
+
)
|
586 |
+
if args.eval_vlc:
|
587 |
+
print("Evaluating on VL-checklist...")
|
588 |
+
vlc_score = evaluate_vlc(
|
589 |
+
model=flamingo,
|
590 |
+
tokenizer=tokenizer,
|
591 |
+
image_processor=image_processor,
|
592 |
+
vis_embed_size=vis_embed_size,
|
593 |
+
rank=args.rank,
|
594 |
+
world_size=args.world_size,
|
595 |
+
id=args.id,
|
596 |
+
)
|
597 |
+
results["vlc"].append(
|
598 |
+
{"score": vlc_score}
|
599 |
+
)
|
600 |
+
if args.eval_crepe:
|
601 |
+
print("Evaluating on CREPE...")
|
602 |
+
crepe_score = evaluate_crepe(
|
603 |
+
model=flamingo,
|
604 |
+
tokenizer=tokenizer,
|
605 |
+
image_processor=image_processor,
|
606 |
+
vis_embed_size=vis_embed_size,
|
607 |
+
rank=args.rank,
|
608 |
+
world_size=args.world_size,
|
609 |
+
id=args.id,
|
610 |
+
level=args.level,
|
611 |
+
type=args.type,
|
612 |
+
)
|
613 |
+
results["crepe"].append(
|
614 |
+
{"score": crepe_score}
|
615 |
+
)
|
616 |
+
if args.eval_cola:
|
617 |
+
print("Evaluating on COLA...")
|
618 |
+
cola_score = evaluate_cola(
|
619 |
+
model=flamingo,
|
620 |
+
tokenizer=tokenizer,
|
621 |
+
image_processor=image_processor,
|
622 |
+
vis_embed_size=vis_embed_size,
|
623 |
+
rank=args.rank,
|
624 |
+
world_size=args.world_size,
|
625 |
+
id=args.id,
|
626 |
+
)
|
627 |
+
results["cola"].append(
|
628 |
+
{"score": cola_score}
|
629 |
+
)
|
630 |
+
|
631 |
+
def prepare_batch_images(batch, image_processor):
|
632 |
+
batch_images = None
|
633 |
+
for b in batch:
|
634 |
+
b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
635 |
+
if batch_images is None:
|
636 |
+
batch_images = b_image
|
637 |
+
else:
|
638 |
+
batch_images = torch.cat([batch_images, b_image], dim=0)
|
639 |
+
return batch_images
|
640 |
+
|
641 |
+
def get_outputs(
|
642 |
+
model,
|
643 |
+
batch_images,
|
644 |
+
attention_mask,
|
645 |
+
max_generation_length,
|
646 |
+
min_generation_length,
|
647 |
+
num_beams,
|
648 |
+
length_penalty,
|
649 |
+
input_ids,
|
650 |
+
image_start_index_list=None,
|
651 |
+
image_nums=None,
|
652 |
+
bad_words_ids=None,
|
653 |
+
):
|
654 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
655 |
+
outputs = model.generate(
|
656 |
+
batch_images,
|
657 |
+
input_ids,
|
658 |
+
attention_mask=attention_mask,
|
659 |
+
max_new_tokens=max_generation_length,
|
660 |
+
min_length=min_generation_length,
|
661 |
+
num_beams=num_beams,
|
662 |
+
length_penalty=length_penalty,
|
663 |
+
image_start_index_list=image_start_index_list,
|
664 |
+
image_nums=image_nums,
|
665 |
+
bad_words_ids=bad_words_ids,
|
666 |
+
)
|
667 |
+
|
668 |
+
outputs = outputs[:, len(input_ids[0]) :]
|
669 |
+
return outputs
|
670 |
+
|
671 |
+
|
672 |
+
def evaluate_vqa(
|
673 |
+
model,
|
674 |
+
tokenizer,
|
675 |
+
image_processor,
|
676 |
+
batch_size,
|
677 |
+
image_dir_path=None,
|
678 |
+
questions_json_path=None,
|
679 |
+
annotations_json_path=None,
|
680 |
+
vqa_dataset="vqa",
|
681 |
+
vis_embed_size=None,
|
682 |
+
rank=0,
|
683 |
+
world_size=1,
|
684 |
+
id=0,
|
685 |
+
):
|
686 |
+
"""
|
687 |
+
Evaluate a model on VQA datasets. Currently supports VQA v2.0.
|
688 |
+
|
689 |
+
Args:
|
690 |
+
model (nn.Module): model to evaluate
|
691 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
692 |
+
image_processor : image processor for the model
|
693 |
+
batch_size (int): batch size
|
694 |
+
image_dir_path (str): path to image directory
|
695 |
+
questions_json_path (str): path to questions json file
|
696 |
+
annotations_json_path (str): path to annotations json file
|
697 |
+
seed (int, optional): random seed. Defaults to 42.
|
698 |
+
max_generation_length (int, optional): max generation length. Defaults to 5.
|
699 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
700 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
701 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
|
702 |
+
query_set_size (int, optional): size of the query set. Defaults to 2048.
|
703 |
+
num_shots (int, optional): number of shots to use. Defaults to 8.
|
704 |
+
device (int, optional): device to use. Defaults to -1 (cpu).
|
705 |
+
num_workers (int, optional): number of workers to use. Defaults to 4.
|
706 |
+
vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
|
707 |
+
Returns:
|
708 |
+
float: accuracy score
|
709 |
+
"""
|
710 |
+
if world_size > 1:
|
711 |
+
torch.distributed.barrier()
|
712 |
+
if vqa_dataset == "gqa":
|
713 |
+
eval_dataset = GQADataset()
|
714 |
+
else:
|
715 |
+
eval_dataset = VQADataset(
|
716 |
+
image_dir_path=image_dir_path,
|
717 |
+
question_path=questions_json_path,
|
718 |
+
annotations_path=annotations_json_path,
|
719 |
+
vqa_dataset=vqa_dataset,
|
720 |
+
)
|
721 |
+
postprocessor = OKVQAPostProcess()
|
722 |
+
try:
|
723 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
724 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
725 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
726 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
727 |
+
except:
|
728 |
+
pass
|
729 |
+
def get_prompt(sample):
|
730 |
+
return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
|
731 |
+
# return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
732 |
+
|
733 |
+
model.eval().cuda()
|
734 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
735 |
+
if "peft" in lang_encoder_name:
|
736 |
+
lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
737 |
+
predictions = []
|
738 |
+
tokenizer.padding_side = "left"
|
739 |
+
if world_size > 1:
|
740 |
+
torch.distributed.barrier()
|
741 |
+
this_tot = 0
|
742 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
743 |
+
tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
|
744 |
+
)):
|
745 |
+
if ii % world_size != rank:
|
746 |
+
continue
|
747 |
+
batch_images = prepare_batch_images(
|
748 |
+
batch=batch,
|
749 |
+
image_processor=image_processor,
|
750 |
+
).cuda()
|
751 |
+
batch_text = [get_prompt(s) for s in batch]
|
752 |
+
encodings = tokenizer(
|
753 |
+
batch_text,
|
754 |
+
return_tensors="pt",
|
755 |
+
padding="longest",
|
756 |
+
truncation=True,
|
757 |
+
max_length=2000,
|
758 |
+
)
|
759 |
+
input_ids = encodings["input_ids"].cuda()
|
760 |
+
attention_mask = encodings["attention_mask"].cuda()
|
761 |
+
skip_special_tokens = True
|
762 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
763 |
+
if rank == 0:
|
764 |
+
tqdm.write("use legacy model")
|
765 |
+
for i in range(len(input_ids)):
|
766 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
767 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
768 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
769 |
+
input_ids[i, media_token_index] = pad_token_id
|
770 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
771 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
772 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
773 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
774 |
+
image_nums = [1] * len(input_ids)
|
775 |
+
if "llama" in lang_encoder_name:
|
776 |
+
attention_mask[input_ids == 0] = 0
|
777 |
+
outputs = get_outputs(
|
778 |
+
model=model,
|
779 |
+
batch_images=batch_images,
|
780 |
+
attention_mask=attention_mask,
|
781 |
+
max_generation_length=10,
|
782 |
+
min_generation_length=1,
|
783 |
+
num_beams=5,
|
784 |
+
length_penalty=0,
|
785 |
+
input_ids=input_ids,
|
786 |
+
image_start_index_list=image_start_index_list,
|
787 |
+
image_nums=image_nums,
|
788 |
+
)
|
789 |
+
# postprocess begin
|
790 |
+
new_predictions = [
|
791 |
+
out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
|
792 |
+
]
|
793 |
+
if vqa_dataset == "ok_vqa":
|
794 |
+
new_predictions = postprocessor._lemmatize(new_predictions)
|
795 |
+
if model.special:
|
796 |
+
for i in range(len(new_predictions)):
|
797 |
+
for answer, _ in Counter(batch[i]['answers']).most_common():
|
798 |
+
if answer in new_predictions[i]:
|
799 |
+
new_predictions[i] = answer
|
800 |
+
break
|
801 |
+
if "cant" in new_predictions[i] and "no" == answer:
|
802 |
+
new_predictions[i] = answer
|
803 |
+
break
|
804 |
+
if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
|
805 |
+
new_predictions[i] = answer
|
806 |
+
break
|
807 |
+
|
808 |
+
this_tot += 1
|
809 |
+
if rank == 0 and this_tot % 20 == 0:
|
810 |
+
for i in range(1):
|
811 |
+
tqdm.write("model output: " + new_predictions[i])
|
812 |
+
|
813 |
+
predictions.extend(
|
814 |
+
[
|
815 |
+
{"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
|
816 |
+
for p, sample in zip(new_predictions, batch)
|
817 |
+
]
|
818 |
+
)
|
819 |
+
with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
|
820 |
+
f.write(json.dumps(predictions))
|
821 |
+
print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
|
822 |
+
|
823 |
+
time.sleep(10)
|
824 |
+
if world_size > 1:
|
825 |
+
torch.distributed.barrier()
|
826 |
+
if rank == 0:
|
827 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
828 |
+
predictions = []
|
829 |
+
for rank_i in range(world_size):
|
830 |
+
print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
831 |
+
predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
|
832 |
+
os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
833 |
+
print("num:", len(predictions))
|
834 |
+
# save the predictions to a temporary file
|
835 |
+
random_uuid = str(uuid.uuid4())
|
836 |
+
with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
|
837 |
+
f.write(json.dumps(predictions, indent=4))
|
838 |
+
|
839 |
+
if vqa_dataset == "gqa":
|
840 |
+
acc = compute_gqa_accuracy(predictions)
|
841 |
+
else:
|
842 |
+
acc = compute_vqa_accuracy(
|
843 |
+
f"{vqa_dataset}results_{random_uuid}.json",
|
844 |
+
questions_json_path,
|
845 |
+
annotations_json_path,
|
846 |
+
vqa_dataset=vqa_dataset,
|
847 |
+
)
|
848 |
+
print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
|
849 |
+
os.makedirs("eval_results", exist_ok=True)
|
850 |
+
with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
851 |
+
f.write(json.dumps(predictions, indent=2))
|
852 |
+
|
853 |
+
# delete the temporary file
|
854 |
+
os.remove(f"{vqa_dataset}results_{random_uuid}.json")
|
855 |
+
else:
|
856 |
+
time.sleep(5)
|
857 |
+
acc = 0.0
|
858 |
+
if world_size > 1:
|
859 |
+
torch.distributed.barrier()
|
860 |
+
return acc
|
861 |
+
|
862 |
+
|
863 |
+
def evaluate_refcoco(
|
864 |
+
model,
|
865 |
+
tokenizer,
|
866 |
+
image_processor,
|
867 |
+
batch_size,
|
868 |
+
tsvfile,
|
869 |
+
max_generation_length=20,
|
870 |
+
num_beams=3,
|
871 |
+
length_penalty=-2.0,
|
872 |
+
device=-1,
|
873 |
+
vis_embed_size=None,
|
874 |
+
rank=0,
|
875 |
+
world_size=1,
|
876 |
+
id=0,
|
877 |
+
):
|
878 |
+
model.eval().cuda()
|
879 |
+
loc_token_ids = []
|
880 |
+
for i in range(1000):
|
881 |
+
loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
882 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
883 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
884 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
885 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
886 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
887 |
+
object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
|
888 |
+
# all_ids = set(range(model.lang_encoder.lm_head.out_features))
|
889 |
+
# bad_words_ids = list(all_ids - set(loc_token_ids))
|
890 |
+
# bad_words_ids = [[b] for b in bad_words_ids]
|
891 |
+
# min_loc_token_id = min(loc_token_ids)
|
892 |
+
# max_loc_token_id = max(loc_token_ids)
|
893 |
+
total = 0
|
894 |
+
correct = 0
|
895 |
+
ious = []
|
896 |
+
if "refcocog" in tsvfile:
|
897 |
+
dataset_name = "refcocog"
|
898 |
+
elif "refcocoplus" in tsvfile:
|
899 |
+
dataset_name = "refcocoplus"
|
900 |
+
else:
|
901 |
+
dataset_name = "refcoco"
|
902 |
+
with open(tsvfile, "r") as f:
|
903 |
+
lines = f.readlines()
|
904 |
+
pbar = tqdm(lines, disable=(rank != 0))
|
905 |
+
for ii, line in enumerate(pbar):
|
906 |
+
if ii % world_size != rank:
|
907 |
+
continue
|
908 |
+
total += 1
|
909 |
+
line = line.rstrip()
|
910 |
+
uniq_id, image_id, text, region_coord, image = line.split("\t")
|
911 |
+
|
912 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
|
913 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
|
914 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
|
915 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
|
916 |
+
|
917 |
+
gt_box = np.array(list(map(float, region_coord.split(","))))
|
918 |
+
width = image.width
|
919 |
+
height = image.height
|
920 |
+
image = image.resize((224, 224))
|
921 |
+
gt_box = gt_box / np.array([width, height, width, height]) * 224
|
922 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
923 |
+
text = text.rstrip('.').strip().replace('"', '').capitalize()
|
924 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text}<|#endofobject#|><|#visual#|>"]
|
925 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
|
926 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
|
927 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
|
928 |
+
|
929 |
+
encodings = tokenizer(
|
930 |
+
prompt,
|
931 |
+
padding="longest",
|
932 |
+
truncation=True,
|
933 |
+
return_tensors="pt",
|
934 |
+
max_length=2000,
|
935 |
+
)
|
936 |
+
input_ids = encodings["input_ids"]
|
937 |
+
attention_mask = encodings["attention_mask"]
|
938 |
+
# attention_mask[input_ids == prebox_token_id] = 0
|
939 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
940 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
941 |
+
image_nums = [1] * len(input_ids)
|
942 |
+
vision_x = batch_images.cuda()
|
943 |
+
lang_x = input_ids.cuda()
|
944 |
+
attention_mask = attention_mask.cuda()
|
945 |
+
|
946 |
+
model.debug_id = 0
|
947 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
948 |
+
outputs = model(
|
949 |
+
vision_x=vision_x,
|
950 |
+
lang_x=lang_x,
|
951 |
+
attention_mask=attention_mask,
|
952 |
+
labels=None,
|
953 |
+
image_nums=image_nums,
|
954 |
+
image_start_index_list=image_start_index_list,
|
955 |
+
added_bbox_list=None,
|
956 |
+
add_box=False,
|
957 |
+
)
|
958 |
+
boxes = outputs["boxes"]
|
959 |
+
scores = outputs["scores"]
|
960 |
+
boxes = boxes[scores >= scores[0]*0.5]
|
961 |
+
scores = scores[scores >= scores[0]*0.5]
|
962 |
+
|
963 |
+
text = text.lower().strip()
|
964 |
+
if text.split(" ")[0] not in ["a", "an", "the", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", "several", "some"]:
|
965 |
+
text = "a " + text
|
966 |
+
losses = []
|
967 |
+
for box, score in zip(boxes, scores):
|
968 |
+
this_prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>There is<|#object#|><|#previsual#|><|#prebox#|><|#object#|> {text}"]
|
969 |
+
encodings = tokenizer(
|
970 |
+
this_prompt,
|
971 |
+
padding="longest",
|
972 |
+
truncation=True,
|
973 |
+
return_tensors="pt",
|
974 |
+
max_length=2000,
|
975 |
+
)
|
976 |
+
input_ids = encodings["input_ids"]
|
977 |
+
attention_mask = encodings["attention_mask"]
|
978 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
979 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
980 |
+
image_nums = [1] * len(input_ids)
|
981 |
+
vision_x = batch_images.cuda()
|
982 |
+
lang_x = input_ids.cuda()
|
983 |
+
attention_mask = attention_mask.cuda()
|
984 |
+
added_bbox_list = [torch.tensor(box / 224).cuda().unsqueeze(0).clamp(0, 0.99)]
|
985 |
+
labels = lang_x.clone()
|
986 |
+
start_idx = (lang_x == object_token_id).nonzero()[-1, -1]
|
987 |
+
labels[0, :start_idx+1] = -100
|
988 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
989 |
+
outputs = model(
|
990 |
+
vision_x=vision_x,
|
991 |
+
lang_x=lang_x,
|
992 |
+
attention_mask=attention_mask,
|
993 |
+
labels=labels,
|
994 |
+
image_nums=image_nums,
|
995 |
+
image_start_index_list=image_start_index_list,
|
996 |
+
added_bbox_list=added_bbox_list,
|
997 |
+
add_box=True,
|
998 |
+
)
|
999 |
+
# print(tokenizer.decode(outputs.logits[0, start_idx].sort(descending=True).indices[:10]))
|
1000 |
+
loss = outputs.loss.detach().cpu()
|
1001 |
+
losses.append((loss.sum() / (loss != 0).sum()).item())
|
1002 |
+
chosen_idx = np.array(losses).argmin()
|
1003 |
+
pred_box = boxes[chosen_idx]
|
1004 |
+
if chosen_idx != 0:
|
1005 |
+
tqdm.write(f"{text}|{chosen_idx}|{scores[chosen_idx]}")
|
1006 |
+
iou = get_iou(pred_box, gt_box)
|
1007 |
+
if iou >= 0.5:
|
1008 |
+
correct += 1
|
1009 |
+
# else:
|
1010 |
+
# if rank == 0:
|
1011 |
+
# tqdm.write(text.rstrip('.').strip().lower())
|
1012 |
+
# open_cv_image = np.array(image)
|
1013 |
+
# # Convert RGB to BGR
|
1014 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
1015 |
+
# open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
|
1016 |
+
# open_cv_image = cv2.rectangle(open_cv_image, gt_box[:2].astype(int), gt_box[2:].astype(int), (0, 255, 0), 2)
|
1017 |
+
# cv2.imwrite(f"refcocog_result/{ii}_{iou}_{text}.jpg", open_cv_image)
|
1018 |
+
pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
|
1019 |
+
# open_cv_image = np.array(image)
|
1020 |
+
# # Convert RGB to BGR
|
1021 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
1022 |
+
# for box, score in zip(boxes, scores):
|
1023 |
+
# open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
|
1024 |
+
# cv2.imwrite("output.jpg", open_cv_image)
|
1025 |
+
# print(boxes)
|
1026 |
+
# print(scores)
|
1027 |
+
# exit()
|
1028 |
+
|
1029 |
+
|
1030 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1031 |
+
f.write(json.dumps([total, correct]))
|
1032 |
+
if world_size > 1:
|
1033 |
+
torch.distributed.barrier()
|
1034 |
+
if rank == 0:
|
1035 |
+
total = 0
|
1036 |
+
correct = 0
|
1037 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1038 |
+
for rank_i in range(world_size):
|
1039 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1040 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1041 |
+
total += total_part
|
1042 |
+
correct += correct_part
|
1043 |
+
score = correct / total
|
1044 |
+
print("score:", score)
|
1045 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
1046 |
+
pass
|
1047 |
+
else:
|
1048 |
+
score = 0.0
|
1049 |
+
if world_size > 1:
|
1050 |
+
torch.distributed.barrier()
|
1051 |
+
return score
|
1052 |
+
|
1053 |
+
|
1054 |
+
|
1055 |
+
# def preprocess_visual_info(Text):
|
1056 |
+
# text = Text.split(" ")
|
1057 |
+
# for is_idx, t in enumerate(text):
|
1058 |
+
# if t == "is":
|
1059 |
+
# break
|
1060 |
+
# the_idx = is_idx
|
1061 |
+
# while text[the_idx] != "the":
|
1062 |
+
# the_idx -= 1
|
1063 |
+
# obj_A = " ".join(text[the_idx+1:is_idx])
|
1064 |
+
# second_the_idx = len(text) - 1
|
1065 |
+
# while text[second_the_idx] != "the":
|
1066 |
+
# second_the_idx -= 1
|
1067 |
+
# obj_B = " ".join(text[second_the_idx+1:])
|
1068 |
+
# visual_obj_A = f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
|
1069 |
+
# visual_obj_B = f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
|
1070 |
+
# Text = Text.replace(obj_A, f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
|
1071 |
+
# Text = Text.replace(obj_B, f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
|
1072 |
+
# return Text, obj_A, obj_B, visual_obj_A, visual_obj_B
|
1073 |
+
|
1074 |
+
|
1075 |
+
def preprocess_visual_info(Text):
|
1076 |
+
text = Text.split(" ")
|
1077 |
+
for is_idx, t in enumerate(text):
|
1078 |
+
if t == "is":
|
1079 |
+
break
|
1080 |
+
the_idx = is_idx
|
1081 |
+
while text[the_idx] != "the":
|
1082 |
+
the_idx -= 1
|
1083 |
+
obj_A = " ".join(text[the_idx+1:is_idx])
|
1084 |
+
second_the_idx = len(text) - 1
|
1085 |
+
while text[second_the_idx] != "the":
|
1086 |
+
second_the_idx -= 1
|
1087 |
+
obj_B = " ".join(text[second_the_idx+1:])
|
1088 |
+
relation = " ".join(text[is_idx+1:second_the_idx])
|
1089 |
+
visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
|
1090 |
+
visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>"
|
1091 |
+
Text = f"{visual_obj_A} is {relation} {visual_obj_B}"
|
1092 |
+
return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation
|
1093 |
+
|
1094 |
+
|
1095 |
+
|
1096 |
+
|
1097 |
+
def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
|
1098 |
+
assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
|
1099 |
+
encodings = tokenizer(
|
1100 |
+
prompt,
|
1101 |
+
padding="longest",
|
1102 |
+
truncation=True,
|
1103 |
+
return_tensors="pt",
|
1104 |
+
max_length=2000,
|
1105 |
+
)
|
1106 |
+
input_ids = encodings["input_ids"]
|
1107 |
+
attention_mask = encodings["attention_mask"]
|
1108 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1109 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1110 |
+
image_nums = [1] * len(input_ids)
|
1111 |
+
vision_x = batch_images.cuda()
|
1112 |
+
lang_x = input_ids.cuda()
|
1113 |
+
attention_mask = attention_mask.cuda()
|
1114 |
+
|
1115 |
+
model.debug_id = 0
|
1116 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
1117 |
+
outputs = model(
|
1118 |
+
vision_x=vision_x,
|
1119 |
+
lang_x=lang_x,
|
1120 |
+
attention_mask=attention_mask,
|
1121 |
+
labels=None,
|
1122 |
+
image_nums=image_nums,
|
1123 |
+
image_start_index_list=image_start_index_list,
|
1124 |
+
added_bbox_list=visual_box_list,
|
1125 |
+
add_box=visual_box_list is not None,
|
1126 |
+
relations=None,
|
1127 |
+
debug_mode=False,
|
1128 |
+
)
|
1129 |
+
boxes = outputs["boxes"]
|
1130 |
+
scores = outputs["scores"]
|
1131 |
+
if debug:
|
1132 |
+
import pdb; pdb.set_trace()
|
1133 |
+
if return_all:
|
1134 |
+
return boxes, scores
|
1135 |
+
if len(scores) == 0:
|
1136 |
+
return None, None
|
1137 |
+
else:
|
1138 |
+
return boxes[scores.argmax()], scores.max()
|
1139 |
+
|
1140 |
+
|
1141 |
+
def evaluate_aro(
|
1142 |
+
model,
|
1143 |
+
tokenizer,
|
1144 |
+
image_processor,
|
1145 |
+
vis_embed_size=None,
|
1146 |
+
rank=0,
|
1147 |
+
world_size=1,
|
1148 |
+
id=0,
|
1149 |
+
add_visual=True,
|
1150 |
+
subset=False,
|
1151 |
+
choose_left_right=False,
|
1152 |
+
):
|
1153 |
+
# os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
|
1154 |
+
dataset_name = "aro"
|
1155 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
1156 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
1157 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
1158 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
1159 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1160 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
1161 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
1162 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
1163 |
+
model.eval().cuda()
|
1164 |
+
total = 0
|
1165 |
+
n_top1 = 0
|
1166 |
+
n_top5 = 0
|
1167 |
+
from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
|
1168 |
+
vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
|
1169 |
+
if subset:
|
1170 |
+
subset_idx = json.load(open("aro_subset.json"))
|
1171 |
+
pbar = tqdm(subset_idx, disable=(rank != 0))
|
1172 |
+
else:
|
1173 |
+
pbar = tqdm(vgr_dataset, disable=(rank != 0))
|
1174 |
+
for ii, sample in enumerate(pbar):
|
1175 |
+
if subset:
|
1176 |
+
ORI_IDX = int(sample)
|
1177 |
+
sample = vgr_dataset[sample]
|
1178 |
+
if ii % world_size != rank:
|
1179 |
+
continue
|
1180 |
+
image = sample["image_options"][0]
|
1181 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
|
1182 |
+
image = image.resize((224, 224))
|
1183 |
+
|
1184 |
+
text = sample["caption_options"][1] # 1 is true caption
|
1185 |
+
# text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog"
|
1186 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1187 |
+
text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
|
1188 |
+
|
1189 |
+
|
1190 |
+
first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
|
1191 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
|
1192 |
+
first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
|
1193 |
+
|
1194 |
+
if first_box is None:
|
1195 |
+
text_A = "the " + obj_A
|
1196 |
+
added_bbox_list = None
|
1197 |
+
else:
|
1198 |
+
text_A = visual_obj_A
|
1199 |
+
added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
|
1200 |
+
|
1201 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
|
1202 |
+
pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
|
1203 |
+
prebox_token_id, return_all=True)
|
1204 |
+
|
1205 |
+
if pre_boxes is None:
|
1206 |
+
pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
|
1207 |
+
pre_scores = [1.0]
|
1208 |
+
|
1209 |
+
logits_list = []
|
1210 |
+
# pre_boxes = [pre_boxes[0]]
|
1211 |
+
# pre_scores = [pre_scores[0]]
|
1212 |
+
for pre_box, pre_score in zip(pre_boxes, pre_scores):
|
1213 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
|
1214 |
+
|
1215 |
+
encodings = tokenizer(
|
1216 |
+
prompt,
|
1217 |
+
padding="longest",
|
1218 |
+
truncation=True,
|
1219 |
+
return_tensors="pt",
|
1220 |
+
max_length=512,
|
1221 |
+
)
|
1222 |
+
input_ids = encodings["input_ids"]
|
1223 |
+
attention_mask = encodings["attention_mask"]
|
1224 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1225 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1226 |
+
image_nums = [1] * len(input_ids)
|
1227 |
+
vision_x = batch_images.cuda()
|
1228 |
+
lang_x = input_ids.cuda()
|
1229 |
+
attention_mask = attention_mask.cuda()
|
1230 |
+
labels = lang_x.clone()
|
1231 |
+
added_bbox_list = None
|
1232 |
+
if add_visual:
|
1233 |
+
added_bbox_list = []
|
1234 |
+
if first_box is not None:
|
1235 |
+
added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
|
1236 |
+
if pre_box is not None:
|
1237 |
+
added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
|
1238 |
+
if added_bbox_list is not None and len(added_bbox_list) == 0:
|
1239 |
+
added_bbox_list = None
|
1240 |
+
|
1241 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
1242 |
+
outputs = model(
|
1243 |
+
vision_x=vision_x,
|
1244 |
+
lang_x=lang_x,
|
1245 |
+
attention_mask=attention_mask,
|
1246 |
+
labels=labels,
|
1247 |
+
image_nums=image_nums,
|
1248 |
+
image_start_index_list=image_start_index_list,
|
1249 |
+
added_bbox_list=added_bbox_list,
|
1250 |
+
add_box=added_bbox_list is not None,
|
1251 |
+
relations=None,
|
1252 |
+
)
|
1253 |
+
logits_list.append([pre_score, outputs.logits])
|
1254 |
+
pre_scores = np.array([x[0] for x in logits_list])
|
1255 |
+
final_probs = 0.0
|
1256 |
+
for score, (_, logits) in zip(pre_scores, logits_list):
|
1257 |
+
final_probs += score * logits.softmax(-1)
|
1258 |
+
assert input_ids.shape[:2] == final_probs.shape[:2]
|
1259 |
+
_rank, is_top1, is_top5 = is_correct(input_ids, final_probs, tokenizer, obj_B, topk=5)
|
1260 |
+
if is_top1:
|
1261 |
+
n_top1 += 1
|
1262 |
+
if is_top5:
|
1263 |
+
n_top5 += 1
|
1264 |
+
total += 1
|
1265 |
+
pbar.set_description(f"acc@top1: {n_top1 / total:.4f} | acc@top5: {n_top5 / total:.4f} | {_rank}")
|
1266 |
+
|
1267 |
+
|
1268 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1269 |
+
f.write(json.dumps([total, n_top1, n_top5]))
|
1270 |
+
if world_size > 1:
|
1271 |
+
torch.distributed.barrier()
|
1272 |
+
if rank == 0:
|
1273 |
+
total = 0
|
1274 |
+
n_top1 = 0
|
1275 |
+
n_top5 = 0
|
1276 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1277 |
+
for rank_i in range(world_size):
|
1278 |
+
[total_part, n_top1_part, n_top5_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1279 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1280 |
+
total += total_part
|
1281 |
+
n_top1 += n_top1_part
|
1282 |
+
n_top5 += n_top5_part
|
1283 |
+
acc_top1 = n_top1 / total
|
1284 |
+
acc_top5 = n_top5 / total
|
1285 |
+
print("acc_top1:", acc_top1, "acc_top5:", acc_top5, "total:", total)
|
1286 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc_top1}_{acc_top5}_{total}_{subset}"), "w") as f:
|
1287 |
+
pass
|
1288 |
+
else:
|
1289 |
+
score = 0.0
|
1290 |
+
if world_size > 1:
|
1291 |
+
torch.distributed.barrier()
|
1292 |
+
return score
|
1293 |
+
|
1294 |
+
|
1295 |
+
def evaluate_pisc(
|
1296 |
+
model,
|
1297 |
+
tokenizer,
|
1298 |
+
image_processor,
|
1299 |
+
batch_size,
|
1300 |
+
tsvfile,
|
1301 |
+
max_generation_length=20,
|
1302 |
+
num_beams=3,
|
1303 |
+
length_penalty=-2.0,
|
1304 |
+
device=-1,
|
1305 |
+
vis_embed_size=None,
|
1306 |
+
rank=0,
|
1307 |
+
world_size=1,
|
1308 |
+
id=0,
|
1309 |
+
add_visual=True,
|
1310 |
+
):
|
1311 |
+
from open_flamingo.train.instruction_template import PISC_TEMPLATES
|
1312 |
+
dataset_name = "pisc"
|
1313 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
1314 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
1315 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
1316 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
1317 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1318 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
1319 |
+
model.train().cuda()
|
1320 |
+
|
1321 |
+
dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd")
|
1322 |
+
pbar = tqdm(dataset, disable=(rank != 0))
|
1323 |
+
|
1324 |
+
rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
|
1325 |
+
rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)}
|
1326 |
+
gt = []
|
1327 |
+
pred_scores = []
|
1328 |
+
for III, sample in enumerate(pbar):
|
1329 |
+
if III % world_size != rank:
|
1330 |
+
continue
|
1331 |
+
image_path, dataset, data = sample
|
1332 |
+
image = Image.open(image_path)
|
1333 |
+
size = image_processor.transforms[0].size
|
1334 |
+
image = image.resize((size, size))
|
1335 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1336 |
+
boxA = data[0]
|
1337 |
+
boxB = data[1]
|
1338 |
+
gt_relation = data[2]
|
1339 |
+
losses = []
|
1340 |
+
for i_rel, option_rel in enumerate(rel_id_to_type):
|
1341 |
+
text = PISC_TEMPLATES[0].format(relation=option_rel)
|
1342 |
+
added_bbox = [
|
1343 |
+
torch.tensor([boxA]).cuda(),
|
1344 |
+
torch.tensor([boxB]).cuda(),
|
1345 |
+
]
|
1346 |
+
caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}"
|
1347 |
+
encodings = tokenizer(
|
1348 |
+
caption,
|
1349 |
+
padding="longest",
|
1350 |
+
truncation=True,
|
1351 |
+
return_tensors="pt",
|
1352 |
+
max_length=2000,
|
1353 |
+
)
|
1354 |
+
input_ids = encodings["input_ids"]
|
1355 |
+
attention_mask = encodings["attention_mask"]
|
1356 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1357 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1358 |
+
image_nums = [1] * len(input_ids)
|
1359 |
+
vision_x = batch_images.cuda()
|
1360 |
+
lang_x = input_ids.cuda()
|
1361 |
+
attention_mask = attention_mask.cuda()
|
1362 |
+
|
1363 |
+
labels = lang_x.clone()
|
1364 |
+
labels[labels == tokenizer.pad_token_id] = -100
|
1365 |
+
if add_visual:
|
1366 |
+
# endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True))
|
1367 |
+
# endofattr_next_token_index[1] += 1
|
1368 |
+
# endofattr_next_token_id = labels[endofattr_next_token_index]
|
1369 |
+
# </obj><visual><box></attr>NEXT_WORD
|
1370 |
+
# </obj> predict NEXT_WORD
|
1371 |
+
# <visual><box></attr> predict nothing
|
1372 |
+
labels[labels == visual_token_id] = -100
|
1373 |
+
labels[labels == box_token_id] = -100
|
1374 |
+
labels[labels == endofattr_token_id] = -100
|
1375 |
+
# labels[endofattr_next_token_index] = -100
|
1376 |
+
labels[:, 0] = -100
|
1377 |
+
answer_token_id = tokenizer(" Answer").input_ids[0]
|
1378 |
+
answer_token_loc = (input_ids == answer_token_id).nonzero()
|
1379 |
+
for batch_idx, idx in answer_token_loc:
|
1380 |
+
labels[batch_idx][:idx+2] = -100
|
1381 |
+
|
1382 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
1383 |
+
outputs = model(
|
1384 |
+
vision_x=vision_x,
|
1385 |
+
lang_x=lang_x,
|
1386 |
+
attention_mask=attention_mask,
|
1387 |
+
labels=labels,
|
1388 |
+
image_nums=image_nums,
|
1389 |
+
image_start_index_list=image_start_index_list,
|
1390 |
+
added_bbox_list=added_bbox,
|
1391 |
+
add_box=added_bbox is not None,
|
1392 |
+
)
|
1393 |
+
loss_total = outputs.loss.reshape(labels.shape[0], -1)
|
1394 |
+
loss = loss_total.sum() / (loss_total != 0).sum()
|
1395 |
+
losses.append(loss.item())
|
1396 |
+
pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum())
|
1397 |
+
gt.append(rel_type_to_id[gt_relation])
|
1398 |
+
gt = np.array(gt)
|
1399 |
+
pred_scores = np.array(pred_scores)
|
1400 |
+
pred = pred_scores.argmax(1)
|
1401 |
+
|
1402 |
+
|
1403 |
+
print("total num:", len(gt))
|
1404 |
+
recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
|
1405 |
+
print("recalls:", recalls)
|
1406 |
+
|
1407 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1408 |
+
f.write(json.dumps([gt.tolist(), pred.tolist()]))
|
1409 |
+
if world_size > 1:
|
1410 |
+
torch.distributed.barrier()
|
1411 |
+
if rank == 0:
|
1412 |
+
gt = []
|
1413 |
+
pred = []
|
1414 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1415 |
+
for rank_i in range(world_size):
|
1416 |
+
[gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1417 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1418 |
+
gt.extend(gt_part)
|
1419 |
+
pred.extend(pred_part)
|
1420 |
+
print("total num:", len(gt))
|
1421 |
+
recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
|
1422 |
+
print("recalls:", recalls)
|
1423 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f:
|
1424 |
+
f.write(f"{gt}\n")
|
1425 |
+
f.write(f"{pred}\n")
|
1426 |
+
f.write(f"{recalls}\n")
|
1427 |
+
score = 0.0
|
1428 |
+
if world_size > 1:
|
1429 |
+
torch.distributed.barrier()
|
1430 |
+
return score
|
1431 |
+
|
1432 |
+
|
1433 |
+
|
1434 |
+
if __name__ == "__main__":
|
1435 |
+
main()
|
multimodal/build/lib/open_flamingo/eval/evaluate_debug.py
ADDED
@@ -0,0 +1,1159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from math import ceil
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import uuid
|
7 |
+
from collections import defaultdict
|
8 |
+
from typing import Callable
|
9 |
+
import time
|
10 |
+
import cv2
|
11 |
+
|
12 |
+
import more_itertools
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
from coco_metric import compute_cider, postprocess_captioning_generation
|
16 |
+
from eval_datasets import VQADataset, GQADataset
|
17 |
+
from tqdm import tqdm
|
18 |
+
from collections import Counter
|
19 |
+
|
20 |
+
from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
|
21 |
+
from open_flamingo.eval.classification import (
|
22 |
+
compute_per_sample_probs,
|
23 |
+
compute_per_sample_loss,
|
24 |
+
)
|
25 |
+
from open_flamingo.eval.imagenet_utils import (
|
26 |
+
openai_imagenet_classnames,
|
27 |
+
IMAGENET_1K_CLASS_ID_TO_LABEL,
|
28 |
+
)
|
29 |
+
|
30 |
+
from open_flamingo.src.factory import create_model_and_transforms
|
31 |
+
from PIL import Image
|
32 |
+
from io import BytesIO
|
33 |
+
import base64
|
34 |
+
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
|
35 |
+
import string
|
36 |
+
from lavis.datasets.builders import load_dataset
|
37 |
+
|
38 |
+
|
39 |
+
def get_iou(box1, box2):
|
40 |
+
# box1 and box2 should be in the format [x1, y1, x2, y2]
|
41 |
+
intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
|
42 |
+
max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
|
43 |
+
area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
44 |
+
area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
45 |
+
union = area_box1 + area_box2 - intersection
|
46 |
+
iou = intersection / union if union > 0 else 0
|
47 |
+
return iou
|
48 |
+
|
49 |
+
def expand2square(pil_img, background_color):
|
50 |
+
width, height = pil_img.size
|
51 |
+
if width == height:
|
52 |
+
return pil_img
|
53 |
+
elif width > height:
|
54 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
55 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
56 |
+
return result
|
57 |
+
else:
|
58 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
59 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
60 |
+
return result
|
61 |
+
|
62 |
+
parser = argparse.ArgumentParser()
|
63 |
+
parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
|
64 |
+
parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
|
65 |
+
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
|
66 |
+
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
|
67 |
+
parser.add_argument("--checkpoint_path", type=str, required=True)
|
68 |
+
parser.add_argument(
|
69 |
+
"--results_file", type=str, default=None, help="JSON file to save results"
|
70 |
+
)
|
71 |
+
|
72 |
+
# Trial arguments
|
73 |
+
parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
|
74 |
+
parser.add_argument(
|
75 |
+
"--num_trials",
|
76 |
+
type=int,
|
77 |
+
default=1,
|
78 |
+
help="Number of trials to run for each shot using different demonstrations",
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"--trial_seeds",
|
82 |
+
nargs="+",
|
83 |
+
default=[0],
|
84 |
+
help="Seeds to use for each trial for picking demonstrations and eval sets",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
|
88 |
+
)
|
89 |
+
|
90 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
91 |
+
|
92 |
+
# Per-dataset evaluation flags
|
93 |
+
parser.add_argument(
|
94 |
+
"--eval_coco",
|
95 |
+
action="store_true",
|
96 |
+
default=False,
|
97 |
+
help="Whether to evaluate on COCO.",
|
98 |
+
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--eval_vqav2",
|
101 |
+
action="store_true",
|
102 |
+
default=False,
|
103 |
+
help="Whether to evaluate on VQAV2.",
|
104 |
+
)
|
105 |
+
parser.add_argument(
|
106 |
+
"--eval_ok_vqa",
|
107 |
+
action="store_true",
|
108 |
+
default=False,
|
109 |
+
help="Whether to evaluate on OK-VQA.",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--eval_imagenet",
|
113 |
+
action="store_true",
|
114 |
+
default=False,
|
115 |
+
help="Whether to evaluate on ImageNet.",
|
116 |
+
)
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
"--eval_flickr30",
|
120 |
+
action="store_true",
|
121 |
+
default=False,
|
122 |
+
help="Whether to evaluate on Flickr30.",
|
123 |
+
)
|
124 |
+
|
125 |
+
parser.add_argument(
|
126 |
+
"--eval_refcoco",
|
127 |
+
action="store_true",
|
128 |
+
default=False,
|
129 |
+
help="Whether to evaluate on RefCOCO.",
|
130 |
+
)
|
131 |
+
|
132 |
+
# Dataset arguments
|
133 |
+
|
134 |
+
## Flickr30 Dataset
|
135 |
+
parser.add_argument(
|
136 |
+
"--flickr_image_dir_path",
|
137 |
+
type=str,
|
138 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
139 |
+
default=None,
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--flickr_annotations_json_path",
|
143 |
+
type=str,
|
144 |
+
help="Path to the dataset_flickr30k_coco_style.json file.",
|
145 |
+
default=None,
|
146 |
+
)
|
147 |
+
|
148 |
+
## COCO Dataset
|
149 |
+
parser.add_argument(
|
150 |
+
"--coco_image_dir_path",
|
151 |
+
type=str,
|
152 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
153 |
+
default=None,
|
154 |
+
)
|
155 |
+
parser.add_argument(
|
156 |
+
"--coco_annotations_json_path",
|
157 |
+
type=str,
|
158 |
+
default=None,
|
159 |
+
)
|
160 |
+
|
161 |
+
## VQAV2 Dataset
|
162 |
+
parser.add_argument(
|
163 |
+
"--vqav2_image_dir_path",
|
164 |
+
type=str,
|
165 |
+
default=None,
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
"--vqav2_questions_json_path",
|
169 |
+
type=str,
|
170 |
+
default=None,
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
"--vqav2_annotations_json_path",
|
174 |
+
type=str,
|
175 |
+
default=None,
|
176 |
+
)
|
177 |
+
|
178 |
+
## OK-VQA Dataset
|
179 |
+
parser.add_argument(
|
180 |
+
"--ok_vqa_image_dir_path",
|
181 |
+
type=str,
|
182 |
+
help="Path to the vqav2/train2014 directory.",
|
183 |
+
default=None,
|
184 |
+
)
|
185 |
+
parser.add_argument(
|
186 |
+
"--ok_vqa_questions_json_path",
|
187 |
+
type=str,
|
188 |
+
help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
|
189 |
+
default=None,
|
190 |
+
)
|
191 |
+
parser.add_argument(
|
192 |
+
"--ok_vqa_annotations_json_path",
|
193 |
+
type=str,
|
194 |
+
help="Path to the v2_mscoco_train2014_annotations.json file.",
|
195 |
+
default=None,
|
196 |
+
)
|
197 |
+
|
198 |
+
## Imagenet dataset
|
199 |
+
parser.add_argument("--imagenet_root", type=str, default="/tmp")
|
200 |
+
|
201 |
+
## RefCOCO dataset
|
202 |
+
parser.add_argument("--refcoco_tsvfile", type=str, default=None)
|
203 |
+
|
204 |
+
parser.add_argument(
|
205 |
+
"--location_token_num",
|
206 |
+
default=1000,
|
207 |
+
type=int,
|
208 |
+
)
|
209 |
+
# distributed training
|
210 |
+
parser.add_argument(
|
211 |
+
"--dist-url",
|
212 |
+
default="env://",
|
213 |
+
type=str,
|
214 |
+
help="url used to set up distributed training",
|
215 |
+
)
|
216 |
+
parser.add_argument(
|
217 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
218 |
+
)
|
219 |
+
parser.add_argument(
|
220 |
+
"--horovod",
|
221 |
+
default=False,
|
222 |
+
action="store_true",
|
223 |
+
help="Use horovod for distributed training.",
|
224 |
+
)
|
225 |
+
parser.add_argument(
|
226 |
+
"--no-set-device-rank",
|
227 |
+
default=False,
|
228 |
+
action="store_true",
|
229 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
230 |
+
)
|
231 |
+
parser.add_argument(
|
232 |
+
"--dist",
|
233 |
+
default=False,
|
234 |
+
action="store_true",
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--lora",
|
238 |
+
default=False,
|
239 |
+
action="store_true",
|
240 |
+
)
|
241 |
+
parser.add_argument(
|
242 |
+
"--lora_r",
|
243 |
+
default=16,
|
244 |
+
type=int,
|
245 |
+
required=False,
|
246 |
+
)
|
247 |
+
parser.add_argument(
|
248 |
+
"--legacy",
|
249 |
+
default=False,
|
250 |
+
action="store_true",
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--special",
|
254 |
+
default=False,
|
255 |
+
action="store_true",
|
256 |
+
)
|
257 |
+
parser.add_argument(
|
258 |
+
"--id",
|
259 |
+
default=0,
|
260 |
+
type=int,
|
261 |
+
required=False,
|
262 |
+
)
|
263 |
+
|
264 |
+
parser.add_argument(
|
265 |
+
"--eval_gqa",
|
266 |
+
default=False,
|
267 |
+
action="store_true",
|
268 |
+
)
|
269 |
+
parser.add_argument(
|
270 |
+
"--use_sam",
|
271 |
+
default=None,
|
272 |
+
type=str,
|
273 |
+
required=False,
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--add_visual_token",
|
277 |
+
default=False,
|
278 |
+
action="store_true",
|
279 |
+
)
|
280 |
+
parser.add_argument(
|
281 |
+
"--use_format_v2",
|
282 |
+
default=False,
|
283 |
+
action="store_true",
|
284 |
+
)
|
285 |
+
|
286 |
+
|
287 |
+
class OKVQAPostProcess():
|
288 |
+
def __init__(self):
|
289 |
+
self._lemmatizer = None
|
290 |
+
|
291 |
+
def _lemmatize(self, answers):
|
292 |
+
def apply(answer):
|
293 |
+
doc = self.lemmatizer(answer)
|
294 |
+
|
295 |
+
words = []
|
296 |
+
for token in doc:
|
297 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
298 |
+
words.append(token.lemma_)
|
299 |
+
else:
|
300 |
+
words.append(token.text)
|
301 |
+
answer = " ".join(words)
|
302 |
+
|
303 |
+
return answer
|
304 |
+
|
305 |
+
return [apply(answer) for answer in answers]
|
306 |
+
|
307 |
+
@property
|
308 |
+
def lemmatizer(self):
|
309 |
+
if self._lemmatizer is None:
|
310 |
+
try:
|
311 |
+
import spacy
|
312 |
+
|
313 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
314 |
+
except ImportError:
|
315 |
+
logging.error(
|
316 |
+
"""
|
317 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
318 |
+
python -m spacy download en_core_web_sm
|
319 |
+
OR
|
320 |
+
import spacy.cli
|
321 |
+
spacy.cli.download("en_core_web_sm")
|
322 |
+
"""
|
323 |
+
)
|
324 |
+
exit(1)
|
325 |
+
|
326 |
+
return self._lemmatizer
|
327 |
+
|
328 |
+
|
329 |
+
def main():
|
330 |
+
args = parser.parse_args()
|
331 |
+
if args.dist:
|
332 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
333 |
+
print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
|
334 |
+
device_id = init_distributed_device(args)
|
335 |
+
else:
|
336 |
+
args.rank = 0
|
337 |
+
args.world_size = 1
|
338 |
+
print(f"rank: {args.rank} world_size: {args.world_size}")
|
339 |
+
|
340 |
+
if "sam" in args.checkpoint_path:
|
341 |
+
args.use_sam = "vit_l"
|
342 |
+
|
343 |
+
args.add_visual_token = True
|
344 |
+
if "lora" in args.checkpoint_path:
|
345 |
+
args.lora = True
|
346 |
+
|
347 |
+
|
348 |
+
args.add_pe = False
|
349 |
+
args.add_box = False
|
350 |
+
args.relation = False
|
351 |
+
if "debug" in args.checkpoint_path:
|
352 |
+
# args.add_pe = True
|
353 |
+
args.add_box = True
|
354 |
+
if "box" in args.checkpoint_path:
|
355 |
+
args.add_box = True
|
356 |
+
if "pe" in args.checkpoint_path:
|
357 |
+
args.add_pe = True
|
358 |
+
if "rel" in args.checkpoint_path:
|
359 |
+
args.relation = True
|
360 |
+
args.add_pe = False
|
361 |
+
if "previsual" in args.checkpoint_path:
|
362 |
+
args.use_format_v2 = True
|
363 |
+
args.relation = False
|
364 |
+
|
365 |
+
|
366 |
+
|
367 |
+
# load model
|
368 |
+
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
|
369 |
+
args.vision_encoder_path,
|
370 |
+
args.vision_encoder_pretrained,
|
371 |
+
args.lm_path,
|
372 |
+
args.lm_tokenizer_path,
|
373 |
+
location_token_num=args.location_token_num,
|
374 |
+
lora=args.lora,
|
375 |
+
lora_r=16,
|
376 |
+
use_sam=args.use_sam,
|
377 |
+
add_visual_token=args.add_visual_token,
|
378 |
+
use_format_v2=args.use_format_v2,
|
379 |
+
add_box=args.add_box,
|
380 |
+
add_pe=args.add_pe,
|
381 |
+
add_relation=args.relation,
|
382 |
+
)
|
383 |
+
flamingo.use_format_v2 = args.use_format_v2
|
384 |
+
if args.special:
|
385 |
+
flamingo.special = True
|
386 |
+
else:
|
387 |
+
flamingo.special = False
|
388 |
+
if args.legacy:
|
389 |
+
flamingo.legacy = True
|
390 |
+
print("use legacy evaluation")
|
391 |
+
flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
|
392 |
+
flamingo.expr_name = args.checkpoint_path.split("/")[-2]
|
393 |
+
if args.rank == 0:
|
394 |
+
print("legacy", True if hasattr(flamingo, "legacy") else False)
|
395 |
+
print("step:", flamingo.step_num)
|
396 |
+
print("expr:", flamingo.expr_name)
|
397 |
+
print("use format v2:", flamingo.use_format_v2)
|
398 |
+
print(args)
|
399 |
+
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
400 |
+
model_state_dict = {}
|
401 |
+
for key in checkpoint["model_state_dict"].keys():
|
402 |
+
model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
|
403 |
+
if "vision_encoder.logit_scale"in model_state_dict:
|
404 |
+
# previous checkpoint has some unnecessary weights
|
405 |
+
del model_state_dict["vision_encoder.logit_scale"]
|
406 |
+
del model_state_dict["vision_encoder.visual.proj"]
|
407 |
+
del model_state_dict["vision_encoder.visual.ln_post.weight"]
|
408 |
+
del model_state_dict["vision_encoder.visual.ln_post.bias"]
|
409 |
+
flamingo.load_state_dict(model_state_dict, strict=True)
|
410 |
+
results = defaultdict(list)
|
411 |
+
if args.eval_coco:
|
412 |
+
print("Evaluating on COCO...")
|
413 |
+
for shot in args.shots:
|
414 |
+
scores = []
|
415 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
416 |
+
cider_score = evaluate_coco_flickr(
|
417 |
+
model=flamingo,
|
418 |
+
tokenizer=tokenizer,
|
419 |
+
image_processor=image_processor,
|
420 |
+
batch_size=args.batch_size,
|
421 |
+
image_dir_path=args.coco_image_dir_path,
|
422 |
+
annotations_json_path=args.coco_annotations_json_path,
|
423 |
+
device=args.device,
|
424 |
+
seed=seed,
|
425 |
+
vis_embed_size=vis_embed_size,
|
426 |
+
rank=args.rank,
|
427 |
+
world_size=args.world_size,
|
428 |
+
id=args.id,
|
429 |
+
)
|
430 |
+
print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
|
431 |
+
scores.append(cider_score)
|
432 |
+
print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
|
433 |
+
results["coco"].append(
|
434 |
+
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
|
435 |
+
)
|
436 |
+
|
437 |
+
if args.eval_ok_vqa:
|
438 |
+
print("Evaluating on OK-VQA...")
|
439 |
+
for shot in args.shots:
|
440 |
+
scores = []
|
441 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
442 |
+
ok_vqa_score = evaluate_vqa(
|
443 |
+
model=flamingo,
|
444 |
+
tokenizer=tokenizer,
|
445 |
+
image_processor=image_processor,
|
446 |
+
batch_size=args.batch_size,
|
447 |
+
image_dir_path=args.ok_vqa_image_dir_path,
|
448 |
+
questions_json_path=args.ok_vqa_questions_json_path,
|
449 |
+
annotations_json_path=args.ok_vqa_annotations_json_path,
|
450 |
+
vqa_dataset="ok_vqa",
|
451 |
+
vis_embed_size=vis_embed_size,
|
452 |
+
rank=args.rank,
|
453 |
+
world_size=args.world_size,
|
454 |
+
id=args.id,
|
455 |
+
)
|
456 |
+
results["ok_vqa"].append(
|
457 |
+
{"shots": shot, "score": ok_vqa_score}
|
458 |
+
)
|
459 |
+
|
460 |
+
if args.eval_vqav2:
|
461 |
+
print("Evaluating on VQAv2...")
|
462 |
+
for shot in args.shots:
|
463 |
+
scores = []
|
464 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
465 |
+
vqa_score = evaluate_vqa(
|
466 |
+
model=flamingo,
|
467 |
+
tokenizer=tokenizer,
|
468 |
+
image_processor=image_processor,
|
469 |
+
batch_size=args.batch_size,
|
470 |
+
image_dir_path=args.vqav2_image_dir_path,
|
471 |
+
questions_json_path=args.vqav2_questions_json_path,
|
472 |
+
annotations_json_path=args.vqav2_annotations_json_path,
|
473 |
+
vqa_dataset="vqa",
|
474 |
+
vis_embed_size=vis_embed_size,
|
475 |
+
rank=args.rank,
|
476 |
+
world_size=args.world_size,
|
477 |
+
id=args.id,
|
478 |
+
)
|
479 |
+
results["vqav2"].append(
|
480 |
+
{"shots": shot, "score": vqa_score}
|
481 |
+
)
|
482 |
+
|
483 |
+
if args.eval_gqa:
|
484 |
+
print("Evaluating on GQA...")
|
485 |
+
for shot in args.shots:
|
486 |
+
scores = []
|
487 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
488 |
+
vqa_score = evaluate_vqa(
|
489 |
+
model=flamingo,
|
490 |
+
tokenizer=tokenizer,
|
491 |
+
image_processor=image_processor,
|
492 |
+
batch_size=args.batch_size,
|
493 |
+
vqa_dataset="gqa",
|
494 |
+
vis_embed_size=vis_embed_size,
|
495 |
+
rank=args.rank,
|
496 |
+
world_size=args.world_size,
|
497 |
+
id=args.id,
|
498 |
+
)
|
499 |
+
results["gqa"].append(
|
500 |
+
{"shots": shot, "score": vqa_score}
|
501 |
+
)
|
502 |
+
|
503 |
+
if args.eval_imagenet:
|
504 |
+
print("Evaluating on ImageNet...")
|
505 |
+
for shot in args.shots:
|
506 |
+
scores = []
|
507 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
508 |
+
imagenet_score = evaluate_imagenet(
|
509 |
+
model=flamingo,
|
510 |
+
tokenizer=tokenizer,
|
511 |
+
image_processor=image_processor,
|
512 |
+
batch_size=args.batch_size,
|
513 |
+
num_samples=args.num_samples,
|
514 |
+
num_shots=shot,
|
515 |
+
device=args.device,
|
516 |
+
seed=seed,
|
517 |
+
imagenet_root=args.imagenet_root,
|
518 |
+
)
|
519 |
+
print(
|
520 |
+
f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
|
521 |
+
)
|
522 |
+
scores.append(imagenet_score)
|
523 |
+
print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
|
524 |
+
results["imagenet"].append(
|
525 |
+
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
|
526 |
+
)
|
527 |
+
|
528 |
+
if args.eval_refcoco:
|
529 |
+
print("Evaluating on RefCOCO...")
|
530 |
+
refcoco_score = evaluate_refcoco(
|
531 |
+
model=flamingo,
|
532 |
+
tokenizer=tokenizer,
|
533 |
+
image_processor=image_processor,
|
534 |
+
batch_size=args.batch_size,
|
535 |
+
device=args.device,
|
536 |
+
tsvfile=args.refcoco_tsvfile,
|
537 |
+
vis_embed_size=vis_embed_size,
|
538 |
+
rank=args.rank,
|
539 |
+
world_size=args.world_size,
|
540 |
+
id=args.id,
|
541 |
+
)
|
542 |
+
results["refcoco"].append(
|
543 |
+
{"score": refcoco_score}
|
544 |
+
)
|
545 |
+
|
546 |
+
def prepare_batch_images(batch, image_processor):
|
547 |
+
batch_images = None
|
548 |
+
for b in batch:
|
549 |
+
b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
550 |
+
if batch_images is None:
|
551 |
+
batch_images = b_image
|
552 |
+
else:
|
553 |
+
batch_images = torch.cat([batch_images, b_image], dim=0)
|
554 |
+
return batch_images
|
555 |
+
|
556 |
+
def get_outputs(
|
557 |
+
model,
|
558 |
+
batch_images,
|
559 |
+
attention_mask,
|
560 |
+
max_generation_length,
|
561 |
+
min_generation_length,
|
562 |
+
num_beams,
|
563 |
+
length_penalty,
|
564 |
+
input_ids,
|
565 |
+
image_start_index_list=None,
|
566 |
+
image_nums=None,
|
567 |
+
bad_words_ids=None,
|
568 |
+
):
|
569 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
570 |
+
outputs = model.generate(
|
571 |
+
batch_images,
|
572 |
+
input_ids,
|
573 |
+
attention_mask=attention_mask,
|
574 |
+
max_new_tokens=max_generation_length,
|
575 |
+
min_length=min_generation_length,
|
576 |
+
num_beams=num_beams,
|
577 |
+
length_penalty=length_penalty,
|
578 |
+
image_start_index_list=image_start_index_list,
|
579 |
+
image_nums=image_nums,
|
580 |
+
bad_words_ids=bad_words_ids,
|
581 |
+
)
|
582 |
+
|
583 |
+
outputs = outputs[:, len(input_ids[0]) :]
|
584 |
+
return outputs
|
585 |
+
|
586 |
+
|
587 |
+
def evaluate_coco_flickr(
|
588 |
+
model,
|
589 |
+
tokenizer,
|
590 |
+
image_processor,
|
591 |
+
batch_size,
|
592 |
+
image_dir_path,
|
593 |
+
annotations_json_path,
|
594 |
+
seed=42,
|
595 |
+
max_generation_length=20,
|
596 |
+
num_beams=1,
|
597 |
+
length_penalty=-2.0,
|
598 |
+
device=-1,
|
599 |
+
is_flickr=False,
|
600 |
+
vis_embed_size=None,
|
601 |
+
rank=0,
|
602 |
+
world_size=1,
|
603 |
+
id=0,
|
604 |
+
):
|
605 |
+
"""Evaluate a model on COCO dataset.
|
606 |
+
|
607 |
+
Args:
|
608 |
+
model (nn.Module): model to evaluate
|
609 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
610 |
+
image_processor : image processor for the model
|
611 |
+
batch_size (int): batch size
|
612 |
+
image_dir_path (str, optional): path to the directory containing the images.
|
613 |
+
annotations_json_path (str, optional): path to the json file containing the annotations.
|
614 |
+
seed (int, optional): seed for random number generator. Defaults to 42.
|
615 |
+
max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
|
616 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
617 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
618 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
|
619 |
+
query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
|
620 |
+
num_shots (int, optional): number of in-context samples to use. Defaults to 8.
|
621 |
+
device (int, optional): device to use. Defaults to -1.
|
622 |
+
num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
|
623 |
+
is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
|
624 |
+
|
625 |
+
Returns:
|
626 |
+
float: CIDEr score
|
627 |
+
|
628 |
+
"""
|
629 |
+
# eval_dataset = COCOFlickrDataset(
|
630 |
+
# image_dir_path=image_dir_path,
|
631 |
+
# annotations_path=annotations_json_path,
|
632 |
+
# is_flickr=is_flickr,
|
633 |
+
# )
|
634 |
+
coco_dataset = load_dataset("coco_caption")
|
635 |
+
eval_dataset = coco_dataset["test"]
|
636 |
+
|
637 |
+
|
638 |
+
model.eval().cuda()
|
639 |
+
predictions = defaultdict()
|
640 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
641 |
+
# if "peft" in lang_encoder_name:
|
642 |
+
# lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
643 |
+
try:
|
644 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
645 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
646 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
647 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
648 |
+
except:
|
649 |
+
pass
|
650 |
+
|
651 |
+
def get_prompt(sample):
|
652 |
+
return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
653 |
+
|
654 |
+
tokenizer.padding_side = "left"
|
655 |
+
cnt = 0
|
656 |
+
if world_size > 1:
|
657 |
+
torch.distributed.barrier()
|
658 |
+
desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
|
659 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
660 |
+
tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
|
661 |
+
)):
|
662 |
+
if ii % world_size != rank:
|
663 |
+
continue
|
664 |
+
cnt += len(batch)
|
665 |
+
batch_images = prepare_batch_images(
|
666 |
+
batch=batch,
|
667 |
+
image_processor=image_processor,
|
668 |
+
).cuda()
|
669 |
+
batch_text = [get_prompt(s) for s in batch]
|
670 |
+
encodings = tokenizer(
|
671 |
+
batch_text,
|
672 |
+
padding="longest",
|
673 |
+
truncation=True,
|
674 |
+
return_tensors="pt",
|
675 |
+
max_length=2000,
|
676 |
+
)
|
677 |
+
input_ids = encodings["input_ids"].cuda()
|
678 |
+
attention_mask = encodings["attention_mask"].cuda()
|
679 |
+
skip_special_tokens = False
|
680 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
681 |
+
if rank == 0:
|
682 |
+
tqdm.write("use legacy model")
|
683 |
+
skip_special_tokens = True
|
684 |
+
for i in range(len(input_ids)):
|
685 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
686 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
687 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
688 |
+
input_ids[i, media_token_index] = pad_token_id
|
689 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
690 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
691 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
692 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
693 |
+
image_nums = [1] * len(input_ids)
|
694 |
+
if "llama" in lang_encoder_name:
|
695 |
+
attention_mask[input_ids == 0] = 0
|
696 |
+
outputs = get_outputs(
|
697 |
+
model=model,
|
698 |
+
batch_images=batch_images,
|
699 |
+
attention_mask=attention_mask,
|
700 |
+
max_generation_length=30,
|
701 |
+
min_generation_length=8,
|
702 |
+
num_beams=5,
|
703 |
+
length_penalty=0,
|
704 |
+
input_ids=input_ids,
|
705 |
+
image_start_index_list=image_start_index_list,
|
706 |
+
image_nums=image_nums,
|
707 |
+
)
|
708 |
+
new_predictions = [
|
709 |
+
postprocess_captioning_generation(out).replace('"', "")
|
710 |
+
for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
711 |
+
]
|
712 |
+
# if rank == 0:
|
713 |
+
# tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
|
714 |
+
|
715 |
+
for i, sample in enumerate(batch):
|
716 |
+
predictions[int(sample["image_id"])] = {
|
717 |
+
"caption": new_predictions[i],
|
718 |
+
}
|
719 |
+
results_path = (
|
720 |
+
f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
|
721 |
+
if is_flickr
|
722 |
+
else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
|
723 |
+
)
|
724 |
+
with open(results_path, "w") as f:
|
725 |
+
f.write(
|
726 |
+
json.dumps(
|
727 |
+
[
|
728 |
+
{"image_id": k, "caption": predictions[k]["caption"]}
|
729 |
+
for k in predictions
|
730 |
+
],
|
731 |
+
indent=2,
|
732 |
+
)
|
733 |
+
)
|
734 |
+
print("save to", results_path)
|
735 |
+
del predictions
|
736 |
+
time.sleep(10)
|
737 |
+
if world_size > 1:
|
738 |
+
torch.distributed.barrier()
|
739 |
+
if rank == 0:
|
740 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
741 |
+
predictions = []
|
742 |
+
for rank_i in range(world_size):
|
743 |
+
part_results_path = (
|
744 |
+
f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
745 |
+
if is_flickr
|
746 |
+
else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
747 |
+
)
|
748 |
+
print("load", part_results_path)
|
749 |
+
predictions.extend(json.load(open(part_results_path)))
|
750 |
+
os.remove(part_results_path)
|
751 |
+
print("num:", len(predictions))
|
752 |
+
results_path = (
|
753 |
+
f"flickrresults_{lang_encoder_name}.json"
|
754 |
+
if is_flickr
|
755 |
+
else f"cocoresults_{lang_encoder_name}.json"
|
756 |
+
)
|
757 |
+
json.dump(predictions, open(results_path, "w"), indent=2)
|
758 |
+
|
759 |
+
metrics = compute_cider(
|
760 |
+
result_path=results_path,
|
761 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
|
762 |
+
)
|
763 |
+
os.makedirs("eval_results", exist_ok=True)
|
764 |
+
acc = metrics["CIDEr"]
|
765 |
+
with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
766 |
+
f.write(json.dumps(predictions, indent=2))
|
767 |
+
|
768 |
+
# delete the temporary file
|
769 |
+
os.remove(results_path)
|
770 |
+
else:
|
771 |
+
metrics = {}
|
772 |
+
metrics["CIDEr"] = 0.0
|
773 |
+
|
774 |
+
return metrics["CIDEr"]
|
775 |
+
|
776 |
+
|
777 |
+
def evaluate_vqa(
|
778 |
+
model,
|
779 |
+
tokenizer,
|
780 |
+
image_processor,
|
781 |
+
batch_size,
|
782 |
+
image_dir_path=None,
|
783 |
+
questions_json_path=None,
|
784 |
+
annotations_json_path=None,
|
785 |
+
vqa_dataset="vqa",
|
786 |
+
vis_embed_size=None,
|
787 |
+
rank=0,
|
788 |
+
world_size=1,
|
789 |
+
id=0,
|
790 |
+
):
|
791 |
+
"""
|
792 |
+
Evaluate a model on VQA datasets. Currently supports VQA v2.0.
|
793 |
+
|
794 |
+
Args:
|
795 |
+
model (nn.Module): model to evaluate
|
796 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
797 |
+
image_processor : image processor for the model
|
798 |
+
batch_size (int): batch size
|
799 |
+
image_dir_path (str): path to image directory
|
800 |
+
questions_json_path (str): path to questions json file
|
801 |
+
annotations_json_path (str): path to annotations json file
|
802 |
+
seed (int, optional): random seed. Defaults to 42.
|
803 |
+
max_generation_length (int, optional): max generation length. Defaults to 5.
|
804 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
805 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
806 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
|
807 |
+
query_set_size (int, optional): size of the query set. Defaults to 2048.
|
808 |
+
num_shots (int, optional): number of shots to use. Defaults to 8.
|
809 |
+
device (int, optional): device to use. Defaults to -1 (cpu).
|
810 |
+
num_workers (int, optional): number of workers to use. Defaults to 4.
|
811 |
+
vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
|
812 |
+
Returns:
|
813 |
+
float: accuracy score
|
814 |
+
"""
|
815 |
+
if world_size > 1:
|
816 |
+
torch.distributed.barrier()
|
817 |
+
if vqa_dataset == "gqa":
|
818 |
+
eval_dataset = GQADataset()
|
819 |
+
else:
|
820 |
+
eval_dataset = VQADataset(
|
821 |
+
image_dir_path=image_dir_path,
|
822 |
+
question_path=questions_json_path,
|
823 |
+
annotations_path=annotations_json_path,
|
824 |
+
vqa_dataset=vqa_dataset,
|
825 |
+
)
|
826 |
+
postprocessor = OKVQAPostProcess()
|
827 |
+
try:
|
828 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
829 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
830 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
831 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
832 |
+
except:
|
833 |
+
pass
|
834 |
+
def get_prompt(sample):
|
835 |
+
return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
|
836 |
+
# return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
837 |
+
|
838 |
+
model.eval().cuda()
|
839 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
840 |
+
if "peft" in lang_encoder_name:
|
841 |
+
lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
842 |
+
predictions = []
|
843 |
+
tokenizer.padding_side = "left"
|
844 |
+
if world_size > 1:
|
845 |
+
torch.distributed.barrier()
|
846 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
847 |
+
tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
|
848 |
+
)):
|
849 |
+
if ii % world_size != rank:
|
850 |
+
continue
|
851 |
+
batch_images = prepare_batch_images(
|
852 |
+
batch=batch,
|
853 |
+
image_processor=image_processor,
|
854 |
+
).cuda()
|
855 |
+
batch_text = [get_prompt(s) for s in batch]
|
856 |
+
encodings = tokenizer(
|
857 |
+
batch_text,
|
858 |
+
return_tensors="pt",
|
859 |
+
padding="longest",
|
860 |
+
truncation=True,
|
861 |
+
max_length=2000,
|
862 |
+
)
|
863 |
+
input_ids = encodings["input_ids"].cuda()
|
864 |
+
attention_mask = encodings["attention_mask"].cuda()
|
865 |
+
skip_special_tokens = True
|
866 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
867 |
+
if rank == 0:
|
868 |
+
tqdm.write("use legacy model")
|
869 |
+
for i in range(len(input_ids)):
|
870 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
871 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
872 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
873 |
+
input_ids[i, media_token_index] = pad_token_id
|
874 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
875 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
876 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
877 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
878 |
+
image_nums = [1] * len(input_ids)
|
879 |
+
if "llama" in lang_encoder_name:
|
880 |
+
attention_mask[input_ids == 0] = 0
|
881 |
+
outputs = get_outputs(
|
882 |
+
model=model,
|
883 |
+
batch_images=batch_images,
|
884 |
+
attention_mask=attention_mask,
|
885 |
+
max_generation_length=10,
|
886 |
+
min_generation_length=1,
|
887 |
+
num_beams=5,
|
888 |
+
length_penalty=0,
|
889 |
+
input_ids=input_ids,
|
890 |
+
image_start_index_list=image_start_index_list,
|
891 |
+
image_nums=image_nums,
|
892 |
+
)
|
893 |
+
# postprocess begin
|
894 |
+
new_predictions = [
|
895 |
+
out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
|
896 |
+
]
|
897 |
+
if vqa_dataset == "ok_vqa":
|
898 |
+
new_predictions = postprocessor._lemmatize(new_predictions)
|
899 |
+
if model.special:
|
900 |
+
for i in range(len(new_predictions)):
|
901 |
+
for answer, _ in Counter(batch[i]['answers']).most_common():
|
902 |
+
if answer in new_predictions[i]:
|
903 |
+
new_predictions[i] = answer
|
904 |
+
break
|
905 |
+
if "cant" in new_predictions[i] and "no" == answer:
|
906 |
+
new_predictions[i] = answer
|
907 |
+
break
|
908 |
+
if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
|
909 |
+
new_predictions[i] = answer
|
910 |
+
break
|
911 |
+
|
912 |
+
# if rank == 0:
|
913 |
+
# tqdm.write(f"{image_nums} {image_start_index_list}")
|
914 |
+
# for i in range(1):
|
915 |
+
# tqdm.write(f"ID: {batch[i]['question_id']} | gt QA: {batch[i]['question']} {Counter(batch[i]['answers']).most_common()}")
|
916 |
+
# tqdm.write("prompt: " + tokenizer.decode(input_ids[i]))
|
917 |
+
# tqdm.write("model output: " + new_predictions[i])
|
918 |
+
|
919 |
+
predictions.extend(
|
920 |
+
[
|
921 |
+
{"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
|
922 |
+
for p, sample in zip(new_predictions, batch)
|
923 |
+
]
|
924 |
+
)
|
925 |
+
with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
|
926 |
+
f.write(json.dumps(predictions))
|
927 |
+
print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
|
928 |
+
|
929 |
+
time.sleep(10)
|
930 |
+
if world_size > 1:
|
931 |
+
torch.distributed.barrier()
|
932 |
+
if rank == 0:
|
933 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
934 |
+
predictions = []
|
935 |
+
for rank_i in range(world_size):
|
936 |
+
print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
937 |
+
predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
|
938 |
+
os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
939 |
+
print("num:", len(predictions))
|
940 |
+
# save the predictions to a temporary file
|
941 |
+
random_uuid = str(uuid.uuid4())
|
942 |
+
with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
|
943 |
+
f.write(json.dumps(predictions, indent=4))
|
944 |
+
|
945 |
+
if vqa_dataset == "gqa":
|
946 |
+
acc = compute_gqa_accuracy(predictions)
|
947 |
+
else:
|
948 |
+
acc = compute_vqa_accuracy(
|
949 |
+
f"{vqa_dataset}results_{random_uuid}.json",
|
950 |
+
questions_json_path,
|
951 |
+
annotations_json_path,
|
952 |
+
vqa_dataset=vqa_dataset,
|
953 |
+
)
|
954 |
+
print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
|
955 |
+
os.makedirs("eval_results", exist_ok=True)
|
956 |
+
with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
957 |
+
f.write(json.dumps(predictions, indent=2))
|
958 |
+
|
959 |
+
# delete the temporary file
|
960 |
+
os.remove(f"{vqa_dataset}results_{random_uuid}.json")
|
961 |
+
else:
|
962 |
+
time.sleep(5)
|
963 |
+
acc = 0.0
|
964 |
+
if world_size > 1:
|
965 |
+
torch.distributed.barrier()
|
966 |
+
return acc
|
967 |
+
|
968 |
+
|
969 |
+
def evaluate_refcoco(
|
970 |
+
model,
|
971 |
+
tokenizer,
|
972 |
+
image_processor,
|
973 |
+
batch_size,
|
974 |
+
tsvfile,
|
975 |
+
max_generation_length=20,
|
976 |
+
num_beams=3,
|
977 |
+
length_penalty=-2.0,
|
978 |
+
device=-1,
|
979 |
+
vis_embed_size=None,
|
980 |
+
rank=0,
|
981 |
+
world_size=1,
|
982 |
+
id=0,
|
983 |
+
):
|
984 |
+
model.eval().cuda()
|
985 |
+
loc_token_ids = []
|
986 |
+
for i in range(1000):
|
987 |
+
loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
988 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
989 |
+
total = 0
|
990 |
+
correct = 0
|
991 |
+
ious = []
|
992 |
+
if "refcocog" in tsvfile:
|
993 |
+
dataset_name = "refcocog"
|
994 |
+
elif "refcocoplus" in tsvfile:
|
995 |
+
dataset_name = "refcocoplus"
|
996 |
+
else:
|
997 |
+
dataset_name = "refcoco"
|
998 |
+
with open(tsvfile, "r") as f:
|
999 |
+
lines = f.readlines()
|
1000 |
+
pbar = tqdm(lines, disable=(rank != 0))
|
1001 |
+
for ii, line in enumerate(pbar):
|
1002 |
+
if ii % world_size != rank:
|
1003 |
+
continue
|
1004 |
+
total += 1
|
1005 |
+
line = line.rstrip()
|
1006 |
+
uniq_id, image_id, text, region_coord, image = line.split("\t")
|
1007 |
+
|
1008 |
+
# image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
|
1009 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
|
1010 |
+
# image2 = Image.open("yolo.png").convert("RGB")
|
1011 |
+
# image1 = image1.resize((224, 224))
|
1012 |
+
# image2 = image2.resize((224, 224))
|
1013 |
+
# images = [image1, image2]
|
1014 |
+
|
1015 |
+
# gt_box = np.array(list(map(float, region_coord.split(","))))
|
1016 |
+
# width = image.width
|
1017 |
+
# height = image.height
|
1018 |
+
# gt_box /= np.array([width, height, width, height])
|
1019 |
+
# batch_images = [image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0) for image in images]
|
1020 |
+
# batch_images = torch.cat(batch_images, dim=0)
|
1021 |
+
# image = Image.open("yolo_test.png").convert("RGB")
|
1022 |
+
image = Image.open("example.png").convert("RGB")
|
1023 |
+
image = image.resize((224, 224))
|
1024 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1025 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text.rstrip('.')}<|#visual#|>"]
|
1026 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#endofattr#|>man<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|> is sitting on<|#object#|><|#previsual#|>"]
|
1027 |
+
# prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|>man<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|> is sitting on<|#object#|><|#previsual#|>"]
|
1028 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
|
1029 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
|
1030 |
+
|
1031 |
+
|
1032 |
+
encodings = tokenizer(
|
1033 |
+
prompt,
|
1034 |
+
padding="longest",
|
1035 |
+
truncation=True,
|
1036 |
+
return_tensors="pt",
|
1037 |
+
max_length=2000,
|
1038 |
+
)
|
1039 |
+
input_ids = encodings["input_ids"]
|
1040 |
+
attention_mask = encodings["attention_mask"]
|
1041 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1042 |
+
image_start_index_list = [image_start_index_list]
|
1043 |
+
image_nums = [1]
|
1044 |
+
vision_x = batch_images.cuda()
|
1045 |
+
lang_x = input_ids.cuda()
|
1046 |
+
attention_mask = attention_mask.cuda()
|
1047 |
+
print(image_start_index_list, image_nums)
|
1048 |
+
|
1049 |
+
model.debug_id = 0
|
1050 |
+
# outputs = get_outputs(
|
1051 |
+
# model=model,
|
1052 |
+
# batch_images=vision_x,
|
1053 |
+
# attention_mask=attention_mask,
|
1054 |
+
# max_generation_length=20,
|
1055 |
+
# min_generation_length=8,
|
1056 |
+
# num_beams=5,
|
1057 |
+
# length_penalty=0,
|
1058 |
+
# input_ids=lang_x,
|
1059 |
+
# image_start_index_list=image_start_index_list,
|
1060 |
+
# image_nums=image_nums,
|
1061 |
+
# )
|
1062 |
+
# print(tokenizer.decode(outputs[0]))
|
1063 |
+
# exit()
|
1064 |
+
|
1065 |
+
prebox = [93, 20, 155, 172] # man
|
1066 |
+
# prebox = [32, 82, 89, 213] # dog
|
1067 |
+
# prebox = [34, 49, 166, 164] # bike
|
1068 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
1069 |
+
outputs = model(
|
1070 |
+
vision_x=vision_x,
|
1071 |
+
lang_x=lang_x,
|
1072 |
+
attention_mask=attention_mask,
|
1073 |
+
labels=None,
|
1074 |
+
image_nums=image_nums,
|
1075 |
+
image_start_index_list=image_start_index_list,
|
1076 |
+
added_bbox_list=[torch.tensor(prebox).cuda().unsqueeze(0) / 224],
|
1077 |
+
add_box=True,
|
1078 |
+
debug_mode=True,
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
boxes = outputs["boxes"]
|
1082 |
+
scores = outputs["scores"]
|
1083 |
+
box = boxes[scores.argmax()]
|
1084 |
+
open_cv_image = np.array(image)
|
1085 |
+
# Convert RGB to BGR
|
1086 |
+
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
1087 |
+
open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
|
1088 |
+
open_cv_image = cv2.rectangle(open_cv_image, prebox[:2], prebox[2:], (0, 0, 255), 2)
|
1089 |
+
cv2.imwrite(f"output2.jpg", open_cv_image)
|
1090 |
+
print(box)
|
1091 |
+
print(prebox)
|
1092 |
+
exit()
|
1093 |
+
|
1094 |
+
# force_words = ["man", "table"]
|
1095 |
+
# force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
|
1096 |
+
|
1097 |
+
|
1098 |
+
# sequences, hidden_states_for_each_step = get_outputs(
|
1099 |
+
# model=model,
|
1100 |
+
# batch_images=vision_x,
|
1101 |
+
# attention_mask=attention_mask,
|
1102 |
+
# max_generation_length=20,
|
1103 |
+
# min_generation_length=8,
|
1104 |
+
# num_beams=5,
|
1105 |
+
# length_penalty=0,
|
1106 |
+
# input_ids=lang_x,
|
1107 |
+
# image_start_index_list=image_start_index_list,
|
1108 |
+
# image_nums=image_nums,
|
1109 |
+
# force_words_ids=force_words_ids,
|
1110 |
+
# )
|
1111 |
+
# sequence = sequences[0]
|
1112 |
+
# print(tokenizer.decode(sequence))
|
1113 |
+
# for i, token in enumerate(sequence):
|
1114 |
+
# if token == model.visual_token_id:
|
1115 |
+
# print(tokenizer.decode(sequence[:i+1]))
|
1116 |
+
# if hasattr(model, "debug_id"):
|
1117 |
+
# model.debug_id += 1
|
1118 |
+
# else:
|
1119 |
+
# model.debug_id = 0
|
1120 |
+
# this_lang_x = torch.hstack([lang_x[0], sequence[:i+1]]).unsqueeze(0)
|
1121 |
+
# this_attention_mask = torch.ones_like(this_lang_x).cuda()
|
1122 |
+
# with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
1123 |
+
# _ = model(
|
1124 |
+
# vision_x=vision_x,
|
1125 |
+
# lang_x=this_lang_x,
|
1126 |
+
# attention_mask=this_attention_mask,
|
1127 |
+
# labels=None,
|
1128 |
+
# image_nums=image_nums,
|
1129 |
+
# image_start_index_list=image_start_index_list,
|
1130 |
+
# added_bbox_list=None,
|
1131 |
+
# )
|
1132 |
+
# exit()
|
1133 |
+
|
1134 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1135 |
+
f.write(json.dumps([total, correct]))
|
1136 |
+
if world_size > 1:
|
1137 |
+
torch.distributed.barrier()
|
1138 |
+
if rank == 0:
|
1139 |
+
total = 0
|
1140 |
+
correct = 0
|
1141 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1142 |
+
for rank_i in range(world_size):
|
1143 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1144 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1145 |
+
total += total_part
|
1146 |
+
correct += correct_part
|
1147 |
+
score = correct / total
|
1148 |
+
print("score:", score)
|
1149 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
1150 |
+
pass
|
1151 |
+
else:
|
1152 |
+
score = 0.0
|
1153 |
+
if world_size > 1:
|
1154 |
+
torch.distributed.barrier()
|
1155 |
+
return score
|
1156 |
+
|
1157 |
+
|
1158 |
+
if __name__ == "__main__":
|
1159 |
+
main()
|
multimodal/build/lib/open_flamingo/eval/evaluate_find_showcase.py
ADDED
@@ -0,0 +1,1700 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from math import ceil
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import uuid
|
7 |
+
from collections import defaultdict
|
8 |
+
from typing import Callable
|
9 |
+
import time
|
10 |
+
import cv2
|
11 |
+
import webdataset as wds
|
12 |
+
from sklearn.metrics import recall_score, average_precision_score
|
13 |
+
|
14 |
+
import more_itertools
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from coco_metric import compute_cider, postprocess_captioning_generation
|
18 |
+
from eval_datasets import VQADataset
|
19 |
+
from tqdm import tqdm
|
20 |
+
from collections import Counter
|
21 |
+
|
22 |
+
from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
|
23 |
+
from open_flamingo.eval.classification import (
|
24 |
+
compute_per_sample_probs,
|
25 |
+
compute_per_sample_loss,
|
26 |
+
)
|
27 |
+
from open_flamingo.eval.imagenet_utils import (
|
28 |
+
openai_imagenet_classnames,
|
29 |
+
IMAGENET_1K_CLASS_ID_TO_LABEL,
|
30 |
+
)
|
31 |
+
|
32 |
+
from open_flamingo.src.factory import create_model_and_transforms
|
33 |
+
from PIL import Image
|
34 |
+
from io import BytesIO
|
35 |
+
import base64
|
36 |
+
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
|
37 |
+
import string
|
38 |
+
from lavis.datasets.builders import load_dataset
|
39 |
+
from open_flamingo.eval.task.reg import evaluate_reg
|
40 |
+
from open_flamingo.eval.task.gqa import GQADataset
|
41 |
+
from open_flamingo.eval.task.vl_checklist import evaluate_vlc
|
42 |
+
from open_flamingo.eval.task.crepe import evaluate_crepe
|
43 |
+
|
44 |
+
def get_iou(box1, box2):
|
45 |
+
# box1 and box2 should be in the format [x1, y1, x2, y2]
|
46 |
+
intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
|
47 |
+
max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
|
48 |
+
area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
49 |
+
area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
50 |
+
union = area_box1 + area_box2 - intersection
|
51 |
+
iou = intersection / union if union > 0 else 0
|
52 |
+
return iou
|
53 |
+
|
54 |
+
def expand2square(pil_img, background_color):
|
55 |
+
width, height = pil_img.size
|
56 |
+
if width == height:
|
57 |
+
return pil_img
|
58 |
+
elif width > height:
|
59 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
60 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
61 |
+
return result
|
62 |
+
else:
|
63 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
64 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
65 |
+
return result
|
66 |
+
|
67 |
+
parser = argparse.ArgumentParser()
|
68 |
+
parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
|
69 |
+
parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
|
70 |
+
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
|
71 |
+
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
|
72 |
+
parser.add_argument("--checkpoint_path", type=str, required=True)
|
73 |
+
parser.add_argument(
|
74 |
+
"--results_file", type=str, default=None, help="JSON file to save results"
|
75 |
+
)
|
76 |
+
|
77 |
+
# Trial arguments
|
78 |
+
parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
|
79 |
+
parser.add_argument(
|
80 |
+
"--num_trials",
|
81 |
+
type=int,
|
82 |
+
default=1,
|
83 |
+
help="Number of trials to run for each shot using different demonstrations",
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--trial_seeds",
|
87 |
+
nargs="+",
|
88 |
+
default=[0],
|
89 |
+
help="Seeds to use for each trial for picking demonstrations and eval sets",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
|
93 |
+
)
|
94 |
+
|
95 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
96 |
+
|
97 |
+
# Per-dataset evaluation flags
|
98 |
+
parser.add_argument(
|
99 |
+
"--eval_coco",
|
100 |
+
action="store_true",
|
101 |
+
default=False,
|
102 |
+
help="Whether to evaluate on COCO.",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--eval_vqav2",
|
106 |
+
action="store_true",
|
107 |
+
default=False,
|
108 |
+
help="Whether to evaluate on VQAV2.",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--eval_ok_vqa",
|
112 |
+
action="store_true",
|
113 |
+
default=False,
|
114 |
+
help="Whether to evaluate on OK-VQA.",
|
115 |
+
)
|
116 |
+
parser.add_argument(
|
117 |
+
"--eval_imagenet",
|
118 |
+
action="store_true",
|
119 |
+
default=False,
|
120 |
+
help="Whether to evaluate on ImageNet.",
|
121 |
+
)
|
122 |
+
|
123 |
+
parser.add_argument(
|
124 |
+
"--eval_flickr30",
|
125 |
+
action="store_true",
|
126 |
+
default=False,
|
127 |
+
help="Whether to evaluate on Flickr30.",
|
128 |
+
)
|
129 |
+
|
130 |
+
parser.add_argument(
|
131 |
+
"--eval_refcoco",
|
132 |
+
action="store_true",
|
133 |
+
default=False,
|
134 |
+
help="Whether to evaluate on RefCOCO.",
|
135 |
+
)
|
136 |
+
|
137 |
+
# Dataset arguments
|
138 |
+
|
139 |
+
## Flickr30 Dataset
|
140 |
+
parser.add_argument(
|
141 |
+
"--flickr_image_dir_path",
|
142 |
+
type=str,
|
143 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
144 |
+
default=None,
|
145 |
+
)
|
146 |
+
parser.add_argument(
|
147 |
+
"--flickr_annotations_json_path",
|
148 |
+
type=str,
|
149 |
+
help="Path to the dataset_flickr30k_coco_style.json file.",
|
150 |
+
default=None,
|
151 |
+
)
|
152 |
+
|
153 |
+
## COCO Dataset
|
154 |
+
parser.add_argument(
|
155 |
+
"--coco_image_dir_path",
|
156 |
+
type=str,
|
157 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
158 |
+
default=None,
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--coco_annotations_json_path",
|
162 |
+
type=str,
|
163 |
+
default=None,
|
164 |
+
)
|
165 |
+
|
166 |
+
## VQAV2 Dataset
|
167 |
+
parser.add_argument(
|
168 |
+
"--vqav2_image_dir_path",
|
169 |
+
type=str,
|
170 |
+
default=None,
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
"--vqav2_questions_json_path",
|
174 |
+
type=str,
|
175 |
+
default=None,
|
176 |
+
)
|
177 |
+
parser.add_argument(
|
178 |
+
"--vqav2_annotations_json_path",
|
179 |
+
type=str,
|
180 |
+
default=None,
|
181 |
+
)
|
182 |
+
|
183 |
+
## OK-VQA Dataset
|
184 |
+
parser.add_argument(
|
185 |
+
"--ok_vqa_image_dir_path",
|
186 |
+
type=str,
|
187 |
+
help="Path to the vqav2/train2014 directory.",
|
188 |
+
default=None,
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--ok_vqa_questions_json_path",
|
192 |
+
type=str,
|
193 |
+
help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
|
194 |
+
default=None,
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--ok_vqa_annotations_json_path",
|
198 |
+
type=str,
|
199 |
+
help="Path to the v2_mscoco_train2014_annotations.json file.",
|
200 |
+
default=None,
|
201 |
+
)
|
202 |
+
|
203 |
+
## Imagenet dataset
|
204 |
+
parser.add_argument("--imagenet_root", type=str, default="/tmp")
|
205 |
+
|
206 |
+
## RefCOCO dataset
|
207 |
+
parser.add_argument("--refcoco_tsvfile", type=str, default=None)
|
208 |
+
|
209 |
+
parser.add_argument(
|
210 |
+
"--location_token_num",
|
211 |
+
default=1000,
|
212 |
+
type=int,
|
213 |
+
)
|
214 |
+
# distributed training
|
215 |
+
parser.add_argument(
|
216 |
+
"--dist-url",
|
217 |
+
default="env://",
|
218 |
+
type=str,
|
219 |
+
help="url used to set up distributed training",
|
220 |
+
)
|
221 |
+
parser.add_argument(
|
222 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
223 |
+
)
|
224 |
+
parser.add_argument(
|
225 |
+
"--horovod",
|
226 |
+
default=False,
|
227 |
+
action="store_true",
|
228 |
+
help="Use horovod for distributed training.",
|
229 |
+
)
|
230 |
+
parser.add_argument(
|
231 |
+
"--no-set-device-rank",
|
232 |
+
default=False,
|
233 |
+
action="store_true",
|
234 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--dist",
|
238 |
+
default=False,
|
239 |
+
action="store_true",
|
240 |
+
)
|
241 |
+
parser.add_argument(
|
242 |
+
"--lora",
|
243 |
+
default=False,
|
244 |
+
action="store_true",
|
245 |
+
)
|
246 |
+
parser.add_argument(
|
247 |
+
"--lora_r",
|
248 |
+
default=16,
|
249 |
+
type=int,
|
250 |
+
required=False,
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--legacy",
|
254 |
+
default=False,
|
255 |
+
action="store_true",
|
256 |
+
)
|
257 |
+
parser.add_argument(
|
258 |
+
"--special",
|
259 |
+
default=False,
|
260 |
+
action="store_true",
|
261 |
+
)
|
262 |
+
parser.add_argument(
|
263 |
+
"--id",
|
264 |
+
default=0,
|
265 |
+
type=int,
|
266 |
+
required=False,
|
267 |
+
)
|
268 |
+
|
269 |
+
parser.add_argument(
|
270 |
+
"--eval_gqa",
|
271 |
+
default=False,
|
272 |
+
action="store_true",
|
273 |
+
)
|
274 |
+
parser.add_argument(
|
275 |
+
"--use_sam",
|
276 |
+
default=None,
|
277 |
+
type=str,
|
278 |
+
required=False,
|
279 |
+
)
|
280 |
+
parser.add_argument(
|
281 |
+
"--add_visual_token",
|
282 |
+
default=False,
|
283 |
+
action="store_true",
|
284 |
+
)
|
285 |
+
parser.add_argument(
|
286 |
+
"--use_format_v2",
|
287 |
+
default=False,
|
288 |
+
action="store_true",
|
289 |
+
)
|
290 |
+
parser.add_argument(
|
291 |
+
"--eval_aro",
|
292 |
+
default=False,
|
293 |
+
action="store_true",
|
294 |
+
)
|
295 |
+
parser.add_argument(
|
296 |
+
"--eval_pisc",
|
297 |
+
default=False,
|
298 |
+
action="store_true",
|
299 |
+
)
|
300 |
+
parser.add_argument(
|
301 |
+
"--eval_reg",
|
302 |
+
default=False,
|
303 |
+
action="store_true",
|
304 |
+
)
|
305 |
+
parser.add_argument(
|
306 |
+
"--eval_vlc",
|
307 |
+
default=False,
|
308 |
+
action="store_true",
|
309 |
+
)
|
310 |
+
parser.add_argument(
|
311 |
+
"--eval_crepe",
|
312 |
+
default=False,
|
313 |
+
action="store_true",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--level",
|
317 |
+
default=4,
|
318 |
+
type=int,
|
319 |
+
)
|
320 |
+
parser.add_argument(
|
321 |
+
"--type",
|
322 |
+
default="swap",
|
323 |
+
type=str,
|
324 |
+
)
|
325 |
+
|
326 |
+
|
327 |
+
class OKVQAPostProcess():
|
328 |
+
def __init__(self):
|
329 |
+
self._lemmatizer = None
|
330 |
+
|
331 |
+
def _lemmatize(self, answers):
|
332 |
+
def apply(answer):
|
333 |
+
doc = self.lemmatizer(answer)
|
334 |
+
|
335 |
+
words = []
|
336 |
+
for token in doc:
|
337 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
338 |
+
words.append(token.lemma_)
|
339 |
+
else:
|
340 |
+
words.append(token.text)
|
341 |
+
answer = " ".join(words)
|
342 |
+
|
343 |
+
return answer
|
344 |
+
|
345 |
+
return [apply(answer) for answer in answers]
|
346 |
+
|
347 |
+
@property
|
348 |
+
def lemmatizer(self):
|
349 |
+
if self._lemmatizer is None:
|
350 |
+
try:
|
351 |
+
import spacy
|
352 |
+
|
353 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
354 |
+
except ImportError:
|
355 |
+
logging.error(
|
356 |
+
"""
|
357 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
358 |
+
python -m spacy download en_core_web_sm
|
359 |
+
OR
|
360 |
+
import spacy.cli
|
361 |
+
spacy.cli.download("en_core_web_sm")
|
362 |
+
"""
|
363 |
+
)
|
364 |
+
exit(1)
|
365 |
+
|
366 |
+
return self._lemmatizer
|
367 |
+
|
368 |
+
|
369 |
+
def main():
|
370 |
+
args = parser.parse_args()
|
371 |
+
if args.dist:
|
372 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
373 |
+
print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
|
374 |
+
device_id = init_distributed_device(args)
|
375 |
+
else:
|
376 |
+
args.rank = 0
|
377 |
+
args.world_size = 1
|
378 |
+
print(f"rank: {args.rank} world_size: {args.world_size}")
|
379 |
+
|
380 |
+
if "sam" in args.checkpoint_path:
|
381 |
+
args.use_sam = "vit_l"
|
382 |
+
|
383 |
+
args.add_visual_token = True
|
384 |
+
if "lora" in args.checkpoint_path:
|
385 |
+
args.lora = True
|
386 |
+
|
387 |
+
|
388 |
+
args.add_pe = False
|
389 |
+
args.add_box = True
|
390 |
+
args.relation = False
|
391 |
+
args.enhance_data = False
|
392 |
+
args.use_format_v2 = True
|
393 |
+
|
394 |
+
|
395 |
+
|
396 |
+
import hashlib
|
397 |
+
args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest()
|
398 |
+
|
399 |
+
# load model
|
400 |
+
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
|
401 |
+
args.vision_encoder_path,
|
402 |
+
args.vision_encoder_pretrained,
|
403 |
+
args.lm_path,
|
404 |
+
args.lm_tokenizer_path,
|
405 |
+
location_token_num=args.location_token_num,
|
406 |
+
lora=args.lora,
|
407 |
+
lora_r=16,
|
408 |
+
use_sam=args.use_sam,
|
409 |
+
add_visual_token=args.add_visual_token,
|
410 |
+
use_format_v2=args.use_format_v2,
|
411 |
+
add_box=args.add_box,
|
412 |
+
add_pe=args.add_pe,
|
413 |
+
add_relation=args.relation,
|
414 |
+
enhance_data=args.enhance_data,
|
415 |
+
)
|
416 |
+
flamingo.use_format_v2 = args.use_format_v2
|
417 |
+
if args.special:
|
418 |
+
flamingo.special = True
|
419 |
+
else:
|
420 |
+
flamingo.special = False
|
421 |
+
if args.legacy:
|
422 |
+
flamingo.legacy = True
|
423 |
+
print("use legacy evaluation")
|
424 |
+
flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
|
425 |
+
flamingo.expr_name = args.checkpoint_path.split("/")[-2]
|
426 |
+
if args.rank == 0:
|
427 |
+
print("legacy", True if hasattr(flamingo, "legacy") else False)
|
428 |
+
print("step:", flamingo.step_num)
|
429 |
+
print("expr:", flamingo.expr_name)
|
430 |
+
print("use format v2:", flamingo.use_format_v2)
|
431 |
+
print(args)
|
432 |
+
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
433 |
+
model_state_dict = {}
|
434 |
+
for key in checkpoint["model_state_dict"].keys():
|
435 |
+
model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
|
436 |
+
if "vision_encoder.logit_scale"in model_state_dict:
|
437 |
+
# previous checkpoint has some unnecessary weights
|
438 |
+
del model_state_dict["vision_encoder.logit_scale"]
|
439 |
+
del model_state_dict["vision_encoder.visual.proj"]
|
440 |
+
del model_state_dict["vision_encoder.visual.ln_post.weight"]
|
441 |
+
del model_state_dict["vision_encoder.visual.ln_post.bias"]
|
442 |
+
flamingo.load_state_dict(model_state_dict, strict=True)
|
443 |
+
results = defaultdict(list)
|
444 |
+
if args.eval_coco:
|
445 |
+
print("Evaluating on COCO...")
|
446 |
+
for shot in args.shots:
|
447 |
+
scores = []
|
448 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
449 |
+
cider_score = evaluate_coco_flickr(
|
450 |
+
model=flamingo,
|
451 |
+
tokenizer=tokenizer,
|
452 |
+
image_processor=image_processor,
|
453 |
+
batch_size=args.batch_size,
|
454 |
+
image_dir_path=args.coco_image_dir_path,
|
455 |
+
annotations_json_path=args.coco_annotations_json_path,
|
456 |
+
device=args.device,
|
457 |
+
seed=seed,
|
458 |
+
vis_embed_size=vis_embed_size,
|
459 |
+
rank=args.rank,
|
460 |
+
world_size=args.world_size,
|
461 |
+
id=args.id,
|
462 |
+
)
|
463 |
+
print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
|
464 |
+
scores.append(cider_score)
|
465 |
+
print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
|
466 |
+
results["coco"].append(
|
467 |
+
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
|
468 |
+
)
|
469 |
+
|
470 |
+
if args.eval_ok_vqa:
|
471 |
+
print("Evaluating on OK-VQA...")
|
472 |
+
for shot in args.shots:
|
473 |
+
scores = []
|
474 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
475 |
+
ok_vqa_score = evaluate_vqa(
|
476 |
+
model=flamingo,
|
477 |
+
tokenizer=tokenizer,
|
478 |
+
image_processor=image_processor,
|
479 |
+
batch_size=args.batch_size,
|
480 |
+
image_dir_path=args.ok_vqa_image_dir_path,
|
481 |
+
questions_json_path=args.ok_vqa_questions_json_path,
|
482 |
+
annotations_json_path=args.ok_vqa_annotations_json_path,
|
483 |
+
vqa_dataset="ok_vqa",
|
484 |
+
vis_embed_size=vis_embed_size,
|
485 |
+
rank=args.rank,
|
486 |
+
world_size=args.world_size,
|
487 |
+
id=args.id,
|
488 |
+
)
|
489 |
+
results["ok_vqa"].append(
|
490 |
+
{"shots": shot, "score": ok_vqa_score}
|
491 |
+
)
|
492 |
+
|
493 |
+
if args.eval_vqav2:
|
494 |
+
print("Evaluating on VQAv2...")
|
495 |
+
for shot in args.shots:
|
496 |
+
scores = []
|
497 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
498 |
+
vqa_score = evaluate_vqa(
|
499 |
+
model=flamingo,
|
500 |
+
tokenizer=tokenizer,
|
501 |
+
image_processor=image_processor,
|
502 |
+
batch_size=args.batch_size,
|
503 |
+
image_dir_path=args.vqav2_image_dir_path,
|
504 |
+
questions_json_path=args.vqav2_questions_json_path,
|
505 |
+
annotations_json_path=args.vqav2_annotations_json_path,
|
506 |
+
vqa_dataset="vqa",
|
507 |
+
vis_embed_size=vis_embed_size,
|
508 |
+
rank=args.rank,
|
509 |
+
world_size=args.world_size,
|
510 |
+
id=args.id,
|
511 |
+
)
|
512 |
+
results["vqav2"].append(
|
513 |
+
{"shots": shot, "score": vqa_score}
|
514 |
+
)
|
515 |
+
|
516 |
+
if args.eval_gqa:
|
517 |
+
print("Evaluating on GQA...")
|
518 |
+
for shot in args.shots:
|
519 |
+
scores = []
|
520 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
521 |
+
vqa_score = evaluate_vqa(
|
522 |
+
model=flamingo,
|
523 |
+
tokenizer=tokenizer,
|
524 |
+
image_processor=image_processor,
|
525 |
+
batch_size=args.batch_size,
|
526 |
+
vqa_dataset="gqa",
|
527 |
+
vis_embed_size=vis_embed_size,
|
528 |
+
rank=args.rank,
|
529 |
+
world_size=args.world_size,
|
530 |
+
id=args.id,
|
531 |
+
)
|
532 |
+
results["gqa"].append(
|
533 |
+
{"shots": shot, "score": vqa_score}
|
534 |
+
)
|
535 |
+
|
536 |
+
if args.eval_refcoco:
|
537 |
+
print("Evaluating on RefCOCO...")
|
538 |
+
refcoco_score = evaluate_refcoco(
|
539 |
+
model=flamingo,
|
540 |
+
tokenizer=tokenizer,
|
541 |
+
image_processor=image_processor,
|
542 |
+
batch_size=args.batch_size,
|
543 |
+
device=args.device,
|
544 |
+
tsvfile=args.refcoco_tsvfile,
|
545 |
+
vis_embed_size=vis_embed_size,
|
546 |
+
rank=args.rank,
|
547 |
+
world_size=args.world_size,
|
548 |
+
id=args.id,
|
549 |
+
)
|
550 |
+
results["refcoco"].append(
|
551 |
+
{"score": refcoco_score}
|
552 |
+
)
|
553 |
+
if args.eval_aro:
|
554 |
+
print("Evaluating on ARO...")
|
555 |
+
aro_score = evaluate_aro(
|
556 |
+
model=flamingo,
|
557 |
+
tokenizer=tokenizer,
|
558 |
+
image_processor=image_processor,
|
559 |
+
batch_size=args.batch_size,
|
560 |
+
device=args.device,
|
561 |
+
tsvfile=args.refcoco_tsvfile,
|
562 |
+
vis_embed_size=vis_embed_size,
|
563 |
+
rank=args.rank,
|
564 |
+
world_size=args.world_size,
|
565 |
+
id=args.id,
|
566 |
+
add_relation=args.relation,
|
567 |
+
)
|
568 |
+
results["aro"].append(
|
569 |
+
{"score": aro_score}
|
570 |
+
)
|
571 |
+
if args.eval_pisc:
|
572 |
+
print("Evaluating on ARO...")
|
573 |
+
aro_score = evaluate_pisc(
|
574 |
+
model=flamingo,
|
575 |
+
tokenizer=tokenizer,
|
576 |
+
image_processor=image_processor,
|
577 |
+
batch_size=args.batch_size,
|
578 |
+
device=args.device,
|
579 |
+
tsvfile=args.refcoco_tsvfile,
|
580 |
+
vis_embed_size=vis_embed_size,
|
581 |
+
rank=args.rank,
|
582 |
+
world_size=args.world_size,
|
583 |
+
id=args.id,
|
584 |
+
)
|
585 |
+
results["pisc"].append(
|
586 |
+
{"score": aro_score}
|
587 |
+
)
|
588 |
+
if args.eval_reg:
|
589 |
+
print("Evaluating on Referring Expression Generation...")
|
590 |
+
cider = evaluate_reg(
|
591 |
+
model=flamingo,
|
592 |
+
tokenizer=tokenizer,
|
593 |
+
image_processor=image_processor,
|
594 |
+
vis_embed_size=vis_embed_size,
|
595 |
+
rank=args.rank,
|
596 |
+
world_size=args.world_size,
|
597 |
+
id=args.id,
|
598 |
+
)
|
599 |
+
results["reg"].append(
|
600 |
+
{"score": cider}
|
601 |
+
)
|
602 |
+
if args.eval_vlc:
|
603 |
+
print("Evaluating on VL-checklist...")
|
604 |
+
vlc_score = evaluate_vlc(
|
605 |
+
model=flamingo,
|
606 |
+
tokenizer=tokenizer,
|
607 |
+
image_processor=image_processor,
|
608 |
+
vis_embed_size=vis_embed_size,
|
609 |
+
rank=args.rank,
|
610 |
+
world_size=args.world_size,
|
611 |
+
id=args.id,
|
612 |
+
)
|
613 |
+
results["vlc"].append(
|
614 |
+
{"score": vlc_score}
|
615 |
+
)
|
616 |
+
if args.eval_crepe:
|
617 |
+
print("Evaluating on CREPE...")
|
618 |
+
crepe_score = evaluate_crepe(
|
619 |
+
model=flamingo,
|
620 |
+
tokenizer=tokenizer,
|
621 |
+
image_processor=image_processor,
|
622 |
+
vis_embed_size=vis_embed_size,
|
623 |
+
rank=args.rank,
|
624 |
+
world_size=args.world_size,
|
625 |
+
id=args.id,
|
626 |
+
level=args.level,
|
627 |
+
type=args.type,
|
628 |
+
)
|
629 |
+
results["crepe"].append(
|
630 |
+
{"score": crepe_score}
|
631 |
+
)
|
632 |
+
|
633 |
+
def prepare_batch_images(batch, image_processor):
|
634 |
+
batch_images = None
|
635 |
+
for b in batch:
|
636 |
+
b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
637 |
+
if batch_images is None:
|
638 |
+
batch_images = b_image
|
639 |
+
else:
|
640 |
+
batch_images = torch.cat([batch_images, b_image], dim=0)
|
641 |
+
return batch_images
|
642 |
+
|
643 |
+
def get_outputs(
|
644 |
+
model,
|
645 |
+
batch_images,
|
646 |
+
attention_mask,
|
647 |
+
max_generation_length,
|
648 |
+
min_generation_length,
|
649 |
+
num_beams,
|
650 |
+
length_penalty,
|
651 |
+
input_ids,
|
652 |
+
image_start_index_list=None,
|
653 |
+
image_nums=None,
|
654 |
+
bad_words_ids=None,
|
655 |
+
):
|
656 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
657 |
+
outputs = model.generate(
|
658 |
+
batch_images,
|
659 |
+
input_ids,
|
660 |
+
attention_mask=attention_mask,
|
661 |
+
max_new_tokens=max_generation_length,
|
662 |
+
min_length=min_generation_length,
|
663 |
+
num_beams=num_beams,
|
664 |
+
length_penalty=length_penalty,
|
665 |
+
image_start_index_list=image_start_index_list,
|
666 |
+
image_nums=image_nums,
|
667 |
+
bad_words_ids=bad_words_ids,
|
668 |
+
)
|
669 |
+
|
670 |
+
outputs = outputs[:, len(input_ids[0]) :]
|
671 |
+
return outputs
|
672 |
+
|
673 |
+
|
674 |
+
def evaluate_coco_flickr(
|
675 |
+
model,
|
676 |
+
tokenizer,
|
677 |
+
image_processor,
|
678 |
+
batch_size,
|
679 |
+
image_dir_path,
|
680 |
+
annotations_json_path,
|
681 |
+
seed=42,
|
682 |
+
max_generation_length=20,
|
683 |
+
num_beams=1,
|
684 |
+
length_penalty=-2.0,
|
685 |
+
device=-1,
|
686 |
+
is_flickr=False,
|
687 |
+
vis_embed_size=None,
|
688 |
+
rank=0,
|
689 |
+
world_size=1,
|
690 |
+
id=0,
|
691 |
+
):
|
692 |
+
"""Evaluate a model on COCO dataset.
|
693 |
+
|
694 |
+
Args:
|
695 |
+
model (nn.Module): model to evaluate
|
696 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
697 |
+
image_processor : image processor for the model
|
698 |
+
batch_size (int): batch size
|
699 |
+
image_dir_path (str, optional): path to the directory containing the images.
|
700 |
+
annotations_json_path (str, optional): path to the json file containing the annotations.
|
701 |
+
seed (int, optional): seed for random number generator. Defaults to 42.
|
702 |
+
max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
|
703 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
704 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
705 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
|
706 |
+
query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
|
707 |
+
num_shots (int, optional): number of in-context samples to use. Defaults to 8.
|
708 |
+
device (int, optional): device to use. Defaults to -1.
|
709 |
+
num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
|
710 |
+
is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
|
711 |
+
|
712 |
+
Returns:
|
713 |
+
float: CIDEr score
|
714 |
+
|
715 |
+
"""
|
716 |
+
# eval_dataset = COCOFlickrDataset(
|
717 |
+
# image_dir_path=image_dir_path,
|
718 |
+
# annotations_path=annotations_json_path,
|
719 |
+
# is_flickr=is_flickr,
|
720 |
+
# )
|
721 |
+
coco_dataset = load_dataset("coco_caption")
|
722 |
+
eval_dataset = coco_dataset["test"]
|
723 |
+
|
724 |
+
|
725 |
+
model.eval().cuda()
|
726 |
+
predictions = defaultdict()
|
727 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
728 |
+
# if "peft" in lang_encoder_name:
|
729 |
+
# lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
730 |
+
try:
|
731 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
732 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
733 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
734 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
735 |
+
except:
|
736 |
+
pass
|
737 |
+
|
738 |
+
def get_prompt(sample):
|
739 |
+
return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
740 |
+
|
741 |
+
tokenizer.padding_side = "left"
|
742 |
+
cnt = 0
|
743 |
+
if world_size > 1:
|
744 |
+
torch.distributed.barrier()
|
745 |
+
desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
|
746 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
747 |
+
tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
|
748 |
+
)):
|
749 |
+
if ii % world_size != rank:
|
750 |
+
continue
|
751 |
+
cnt += len(batch)
|
752 |
+
batch_images = prepare_batch_images(
|
753 |
+
batch=batch,
|
754 |
+
image_processor=image_processor,
|
755 |
+
).cuda()
|
756 |
+
batch_text = [get_prompt(s) for s in batch]
|
757 |
+
encodings = tokenizer(
|
758 |
+
batch_text,
|
759 |
+
padding="longest",
|
760 |
+
truncation=True,
|
761 |
+
return_tensors="pt",
|
762 |
+
max_length=2000,
|
763 |
+
)
|
764 |
+
input_ids = encodings["input_ids"].cuda()
|
765 |
+
attention_mask = encodings["attention_mask"].cuda()
|
766 |
+
skip_special_tokens = False
|
767 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
768 |
+
if rank == 0:
|
769 |
+
tqdm.write("use legacy model")
|
770 |
+
skip_special_tokens = True
|
771 |
+
for i in range(len(input_ids)):
|
772 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
773 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
774 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
775 |
+
input_ids[i, media_token_index] = pad_token_id
|
776 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
777 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
778 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
779 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
780 |
+
image_nums = [1] * len(input_ids)
|
781 |
+
if "llama" in lang_encoder_name:
|
782 |
+
attention_mask[input_ids == 0] = 0
|
783 |
+
outputs = get_outputs(
|
784 |
+
model=model,
|
785 |
+
batch_images=batch_images,
|
786 |
+
attention_mask=attention_mask,
|
787 |
+
max_generation_length=30,
|
788 |
+
min_generation_length=8,
|
789 |
+
num_beams=5,
|
790 |
+
length_penalty=0,
|
791 |
+
input_ids=input_ids,
|
792 |
+
image_start_index_list=image_start_index_list,
|
793 |
+
image_nums=image_nums,
|
794 |
+
)
|
795 |
+
new_predictions = [
|
796 |
+
postprocess_captioning_generation(out).replace('"', "")
|
797 |
+
for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
798 |
+
]
|
799 |
+
# if rank == 0:
|
800 |
+
# tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
|
801 |
+
|
802 |
+
for i, sample in enumerate(batch):
|
803 |
+
predictions[int(sample["image_id"])] = {
|
804 |
+
"caption": new_predictions[i],
|
805 |
+
}
|
806 |
+
results_path = (
|
807 |
+
f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
|
808 |
+
if is_flickr
|
809 |
+
else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
|
810 |
+
)
|
811 |
+
with open(results_path, "w") as f:
|
812 |
+
f.write(
|
813 |
+
json.dumps(
|
814 |
+
[
|
815 |
+
{"image_id": k, "caption": predictions[k]["caption"]}
|
816 |
+
for k in predictions
|
817 |
+
],
|
818 |
+
indent=2,
|
819 |
+
)
|
820 |
+
)
|
821 |
+
print("save to", results_path)
|
822 |
+
del predictions
|
823 |
+
time.sleep(10)
|
824 |
+
if world_size > 1:
|
825 |
+
torch.distributed.barrier()
|
826 |
+
if rank == 0:
|
827 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
828 |
+
predictions = []
|
829 |
+
for rank_i in range(world_size):
|
830 |
+
part_results_path = (
|
831 |
+
f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
832 |
+
if is_flickr
|
833 |
+
else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
834 |
+
)
|
835 |
+
print("load", part_results_path)
|
836 |
+
predictions.extend(json.load(open(part_results_path)))
|
837 |
+
os.remove(part_results_path)
|
838 |
+
print("num:", len(predictions))
|
839 |
+
results_path = (
|
840 |
+
f"flickrresults_{lang_encoder_name}.json"
|
841 |
+
if is_flickr
|
842 |
+
else f"cocoresults_{lang_encoder_name}.json"
|
843 |
+
)
|
844 |
+
json.dump(predictions, open(results_path, "w"), indent=2)
|
845 |
+
|
846 |
+
metrics = compute_cider(
|
847 |
+
result_path=results_path,
|
848 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
|
849 |
+
)
|
850 |
+
os.makedirs("eval_results", exist_ok=True)
|
851 |
+
acc = metrics["CIDEr"]
|
852 |
+
with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
853 |
+
f.write(json.dumps(predictions, indent=2))
|
854 |
+
|
855 |
+
# delete the temporary file
|
856 |
+
os.remove(results_path)
|
857 |
+
else:
|
858 |
+
metrics = {}
|
859 |
+
metrics["CIDEr"] = 0.0
|
860 |
+
|
861 |
+
return metrics["CIDEr"]
|
862 |
+
|
863 |
+
|
864 |
+
def evaluate_vqa(
|
865 |
+
model,
|
866 |
+
tokenizer,
|
867 |
+
image_processor,
|
868 |
+
batch_size,
|
869 |
+
image_dir_path=None,
|
870 |
+
questions_json_path=None,
|
871 |
+
annotations_json_path=None,
|
872 |
+
vqa_dataset="vqa",
|
873 |
+
vis_embed_size=None,
|
874 |
+
rank=0,
|
875 |
+
world_size=1,
|
876 |
+
id=0,
|
877 |
+
):
|
878 |
+
"""
|
879 |
+
Evaluate a model on VQA datasets. Currently supports VQA v2.0.
|
880 |
+
|
881 |
+
Args:
|
882 |
+
model (nn.Module): model to evaluate
|
883 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
884 |
+
image_processor : image processor for the model
|
885 |
+
batch_size (int): batch size
|
886 |
+
image_dir_path (str): path to image directory
|
887 |
+
questions_json_path (str): path to questions json file
|
888 |
+
annotations_json_path (str): path to annotations json file
|
889 |
+
seed (int, optional): random seed. Defaults to 42.
|
890 |
+
max_generation_length (int, optional): max generation length. Defaults to 5.
|
891 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
892 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
893 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
|
894 |
+
query_set_size (int, optional): size of the query set. Defaults to 2048.
|
895 |
+
num_shots (int, optional): number of shots to use. Defaults to 8.
|
896 |
+
device (int, optional): device to use. Defaults to -1 (cpu).
|
897 |
+
num_workers (int, optional): number of workers to use. Defaults to 4.
|
898 |
+
vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
|
899 |
+
Returns:
|
900 |
+
float: accuracy score
|
901 |
+
"""
|
902 |
+
if world_size > 1:
|
903 |
+
torch.distributed.barrier()
|
904 |
+
if vqa_dataset == "gqa":
|
905 |
+
eval_dataset = GQADataset()
|
906 |
+
else:
|
907 |
+
eval_dataset = VQADataset(
|
908 |
+
image_dir_path=image_dir_path,
|
909 |
+
question_path=questions_json_path,
|
910 |
+
annotations_path=annotations_json_path,
|
911 |
+
vqa_dataset=vqa_dataset,
|
912 |
+
)
|
913 |
+
postprocessor = OKVQAPostProcess()
|
914 |
+
try:
|
915 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
916 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
917 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
918 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
919 |
+
except:
|
920 |
+
pass
|
921 |
+
def get_prompt(sample):
|
922 |
+
return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
|
923 |
+
# return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
924 |
+
|
925 |
+
model.eval().cuda()
|
926 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
927 |
+
if "peft" in lang_encoder_name:
|
928 |
+
lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
929 |
+
predictions = []
|
930 |
+
tokenizer.padding_side = "left"
|
931 |
+
if world_size > 1:
|
932 |
+
torch.distributed.barrier()
|
933 |
+
this_tot = 0
|
934 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
935 |
+
tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
|
936 |
+
)):
|
937 |
+
if ii % world_size != rank:
|
938 |
+
continue
|
939 |
+
batch_images = prepare_batch_images(
|
940 |
+
batch=batch,
|
941 |
+
image_processor=image_processor,
|
942 |
+
).cuda()
|
943 |
+
batch_text = [get_prompt(s) for s in batch]
|
944 |
+
encodings = tokenizer(
|
945 |
+
batch_text,
|
946 |
+
return_tensors="pt",
|
947 |
+
padding="longest",
|
948 |
+
truncation=True,
|
949 |
+
max_length=2000,
|
950 |
+
)
|
951 |
+
input_ids = encodings["input_ids"].cuda()
|
952 |
+
attention_mask = encodings["attention_mask"].cuda()
|
953 |
+
skip_special_tokens = True
|
954 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
955 |
+
if rank == 0:
|
956 |
+
tqdm.write("use legacy model")
|
957 |
+
for i in range(len(input_ids)):
|
958 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
959 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
960 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
961 |
+
input_ids[i, media_token_index] = pad_token_id
|
962 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
963 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
964 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
965 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
966 |
+
image_nums = [1] * len(input_ids)
|
967 |
+
if "llama" in lang_encoder_name:
|
968 |
+
attention_mask[input_ids == 0] = 0
|
969 |
+
outputs = get_outputs(
|
970 |
+
model=model,
|
971 |
+
batch_images=batch_images,
|
972 |
+
attention_mask=attention_mask,
|
973 |
+
max_generation_length=10,
|
974 |
+
min_generation_length=1,
|
975 |
+
num_beams=5,
|
976 |
+
length_penalty=0,
|
977 |
+
input_ids=input_ids,
|
978 |
+
image_start_index_list=image_start_index_list,
|
979 |
+
image_nums=image_nums,
|
980 |
+
)
|
981 |
+
# postprocess begin
|
982 |
+
new_predictions = [
|
983 |
+
out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
|
984 |
+
]
|
985 |
+
if vqa_dataset == "ok_vqa":
|
986 |
+
new_predictions = postprocessor._lemmatize(new_predictions)
|
987 |
+
if model.special:
|
988 |
+
for i in range(len(new_predictions)):
|
989 |
+
for answer, _ in Counter(batch[i]['answers']).most_common():
|
990 |
+
if answer in new_predictions[i]:
|
991 |
+
new_predictions[i] = answer
|
992 |
+
break
|
993 |
+
if "cant" in new_predictions[i] and "no" == answer:
|
994 |
+
new_predictions[i] = answer
|
995 |
+
break
|
996 |
+
if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
|
997 |
+
new_predictions[i] = answer
|
998 |
+
break
|
999 |
+
|
1000 |
+
this_tot += 1
|
1001 |
+
if rank == 0 and this_tot % 20 == 0:
|
1002 |
+
for i in range(1):
|
1003 |
+
tqdm.write("model output: " + new_predictions[i])
|
1004 |
+
|
1005 |
+
predictions.extend(
|
1006 |
+
[
|
1007 |
+
{"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
|
1008 |
+
for p, sample in zip(new_predictions, batch)
|
1009 |
+
]
|
1010 |
+
)
|
1011 |
+
with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
|
1012 |
+
f.write(json.dumps(predictions))
|
1013 |
+
print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
|
1014 |
+
|
1015 |
+
time.sleep(10)
|
1016 |
+
if world_size > 1:
|
1017 |
+
torch.distributed.barrier()
|
1018 |
+
if rank == 0:
|
1019 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1020 |
+
predictions = []
|
1021 |
+
for rank_i in range(world_size):
|
1022 |
+
print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
1023 |
+
predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
|
1024 |
+
os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
1025 |
+
print("num:", len(predictions))
|
1026 |
+
# save the predictions to a temporary file
|
1027 |
+
random_uuid = str(uuid.uuid4())
|
1028 |
+
with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
|
1029 |
+
f.write(json.dumps(predictions, indent=4))
|
1030 |
+
|
1031 |
+
if vqa_dataset == "gqa":
|
1032 |
+
acc = compute_gqa_accuracy(predictions)
|
1033 |
+
else:
|
1034 |
+
acc = compute_vqa_accuracy(
|
1035 |
+
f"{vqa_dataset}results_{random_uuid}.json",
|
1036 |
+
questions_json_path,
|
1037 |
+
annotations_json_path,
|
1038 |
+
vqa_dataset=vqa_dataset,
|
1039 |
+
)
|
1040 |
+
print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
|
1041 |
+
os.makedirs("eval_results", exist_ok=True)
|
1042 |
+
with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
1043 |
+
f.write(json.dumps(predictions, indent=2))
|
1044 |
+
|
1045 |
+
# delete the temporary file
|
1046 |
+
os.remove(f"{vqa_dataset}results_{random_uuid}.json")
|
1047 |
+
else:
|
1048 |
+
time.sleep(5)
|
1049 |
+
acc = 0.0
|
1050 |
+
if world_size > 1:
|
1051 |
+
torch.distributed.barrier()
|
1052 |
+
return acc
|
1053 |
+
|
1054 |
+
|
1055 |
+
def evaluate_refcoco(
|
1056 |
+
model,
|
1057 |
+
tokenizer,
|
1058 |
+
image_processor,
|
1059 |
+
batch_size,
|
1060 |
+
tsvfile,
|
1061 |
+
max_generation_length=20,
|
1062 |
+
num_beams=3,
|
1063 |
+
length_penalty=-2.0,
|
1064 |
+
device=-1,
|
1065 |
+
vis_embed_size=None,
|
1066 |
+
rank=0,
|
1067 |
+
world_size=1,
|
1068 |
+
id=0,
|
1069 |
+
):
|
1070 |
+
model.eval().cuda()
|
1071 |
+
loc_token_ids = []
|
1072 |
+
for i in range(1000):
|
1073 |
+
loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
1074 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
1075 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1076 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
1077 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
1078 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
1079 |
+
# all_ids = set(range(model.lang_encoder.lm_head.out_features))
|
1080 |
+
# bad_words_ids = list(all_ids - set(loc_token_ids))
|
1081 |
+
# bad_words_ids = [[b] for b in bad_words_ids]
|
1082 |
+
# min_loc_token_id = min(loc_token_ids)
|
1083 |
+
# max_loc_token_id = max(loc_token_ids)
|
1084 |
+
total = 0
|
1085 |
+
correct = 0
|
1086 |
+
ious = []
|
1087 |
+
if "refcocog" in tsvfile:
|
1088 |
+
dataset_name = "refcocog"
|
1089 |
+
elif "refcocoplus" in tsvfile:
|
1090 |
+
dataset_name = "refcocoplus"
|
1091 |
+
else:
|
1092 |
+
dataset_name = "refcoco"
|
1093 |
+
with open(tsvfile, "r") as f:
|
1094 |
+
lines = f.readlines()
|
1095 |
+
pbar = tqdm(lines, disable=(rank != 0))
|
1096 |
+
for ii, line in enumerate(pbar):
|
1097 |
+
if ii % world_size != rank:
|
1098 |
+
continue
|
1099 |
+
total += 1
|
1100 |
+
line = line.rstrip()
|
1101 |
+
uniq_id, image_id, text, region_coord, image = line.split("\t")
|
1102 |
+
|
1103 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
|
1104 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
|
1105 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
|
1106 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
|
1107 |
+
|
1108 |
+
gt_box = np.array(list(map(float, region_coord.split(","))))
|
1109 |
+
width = image.width
|
1110 |
+
height = image.height
|
1111 |
+
image = image.resize((224, 224))
|
1112 |
+
gt_box = gt_box / np.array([width, height, width, height]) * 224
|
1113 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1114 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
|
1115 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
|
1116 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
|
1117 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
|
1118 |
+
|
1119 |
+
|
1120 |
+
encodings = tokenizer(
|
1121 |
+
prompt,
|
1122 |
+
padding="longest",
|
1123 |
+
truncation=True,
|
1124 |
+
return_tensors="pt",
|
1125 |
+
max_length=2000,
|
1126 |
+
)
|
1127 |
+
input_ids = encodings["input_ids"]
|
1128 |
+
attention_mask = encodings["attention_mask"]
|
1129 |
+
# attention_mask[input_ids == prebox_token_id] = 0
|
1130 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1131 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1132 |
+
image_nums = [1] * len(input_ids)
|
1133 |
+
vision_x = batch_images.cuda()
|
1134 |
+
lang_x = input_ids.cuda()
|
1135 |
+
attention_mask = attention_mask.cuda()
|
1136 |
+
|
1137 |
+
model.debug_id = 0
|
1138 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
1139 |
+
outputs = model(
|
1140 |
+
vision_x=vision_x,
|
1141 |
+
lang_x=lang_x,
|
1142 |
+
attention_mask=attention_mask,
|
1143 |
+
labels=None,
|
1144 |
+
image_nums=image_nums,
|
1145 |
+
image_start_index_list=image_start_index_list,
|
1146 |
+
added_bbox_list=None,
|
1147 |
+
add_box=False,
|
1148 |
+
)
|
1149 |
+
boxes = outputs["boxes"]
|
1150 |
+
scores = outputs["scores"]
|
1151 |
+
if len(scores) > 0:
|
1152 |
+
box = boxes[scores.argmax()]
|
1153 |
+
iou = get_iou(box, gt_box)
|
1154 |
+
else:
|
1155 |
+
iou = 0.0
|
1156 |
+
# tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
|
1157 |
+
tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}")
|
1158 |
+
if iou >= 0.5:
|
1159 |
+
correct += 1
|
1160 |
+
pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
|
1161 |
+
# open_cv_image = np.array(image)
|
1162 |
+
# # Convert RGB to BGR
|
1163 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
1164 |
+
# for box, score in zip(boxes, scores):
|
1165 |
+
# open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
|
1166 |
+
# cv2.imwrite("output.jpg", open_cv_image)
|
1167 |
+
# print(boxes)
|
1168 |
+
# print(scores)
|
1169 |
+
# exit()
|
1170 |
+
|
1171 |
+
|
1172 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1173 |
+
f.write(json.dumps([total, correct]))
|
1174 |
+
if world_size > 1:
|
1175 |
+
torch.distributed.barrier()
|
1176 |
+
if rank == 0:
|
1177 |
+
total = 0
|
1178 |
+
correct = 0
|
1179 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1180 |
+
for rank_i in range(world_size):
|
1181 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1182 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1183 |
+
total += total_part
|
1184 |
+
correct += correct_part
|
1185 |
+
score = correct / total
|
1186 |
+
print("score:", score)
|
1187 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
1188 |
+
pass
|
1189 |
+
else:
|
1190 |
+
score = 0.0
|
1191 |
+
if world_size > 1:
|
1192 |
+
torch.distributed.barrier()
|
1193 |
+
return score
|
1194 |
+
|
1195 |
+
|
1196 |
+
|
1197 |
+
# def preprocess_visual_info(Text):
|
1198 |
+
# text = Text.split(" ")
|
1199 |
+
# for is_idx, t in enumerate(text):
|
1200 |
+
# if t == "is":
|
1201 |
+
# break
|
1202 |
+
# the_idx = is_idx
|
1203 |
+
# while text[the_idx] != "the":
|
1204 |
+
# the_idx -= 1
|
1205 |
+
# obj_A = " ".join(text[the_idx+1:is_idx])
|
1206 |
+
# second_the_idx = len(text) - 1
|
1207 |
+
# while text[second_the_idx] != "the":
|
1208 |
+
# second_the_idx -= 1
|
1209 |
+
# obj_B = " ".join(text[second_the_idx+1:])
|
1210 |
+
# visual_obj_A = f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
|
1211 |
+
# visual_obj_B = f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
|
1212 |
+
# Text = Text.replace(obj_A, f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
|
1213 |
+
# Text = Text.replace(obj_B, f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
|
1214 |
+
# return Text, obj_A, obj_B, visual_obj_A, visual_obj_B
|
1215 |
+
|
1216 |
+
|
1217 |
+
def preprocess_visual_info(Text):
|
1218 |
+
text = Text.split(" ")
|
1219 |
+
for is_idx, t in enumerate(text):
|
1220 |
+
if t == "is":
|
1221 |
+
break
|
1222 |
+
the_idx = is_idx
|
1223 |
+
while text[the_idx] != "the":
|
1224 |
+
the_idx -= 1
|
1225 |
+
obj_A = " ".join(text[the_idx+1:is_idx])
|
1226 |
+
second_the_idx = len(text) - 1
|
1227 |
+
while text[second_the_idx] != "the":
|
1228 |
+
second_the_idx -= 1
|
1229 |
+
obj_B = " ".join(text[second_the_idx+1:])
|
1230 |
+
relation = " ".join(text[is_idx+1:second_the_idx])
|
1231 |
+
visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
|
1232 |
+
visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>"
|
1233 |
+
Text = f"{visual_obj_A} is {relation} {visual_obj_B}"
|
1234 |
+
return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation
|
1235 |
+
|
1236 |
+
|
1237 |
+
|
1238 |
+
|
1239 |
+
def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
|
1240 |
+
assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
|
1241 |
+
encodings = tokenizer(
|
1242 |
+
prompt,
|
1243 |
+
padding="longest",
|
1244 |
+
truncation=True,
|
1245 |
+
return_tensors="pt",
|
1246 |
+
max_length=2000,
|
1247 |
+
)
|
1248 |
+
input_ids = encodings["input_ids"]
|
1249 |
+
attention_mask = encodings["attention_mask"]
|
1250 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1251 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1252 |
+
image_nums = [1] * len(input_ids)
|
1253 |
+
vision_x = batch_images.cuda()
|
1254 |
+
lang_x = input_ids.cuda()
|
1255 |
+
attention_mask = attention_mask.cuda()
|
1256 |
+
|
1257 |
+
model.debug_id = 0
|
1258 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
1259 |
+
outputs = model(
|
1260 |
+
vision_x=vision_x,
|
1261 |
+
lang_x=lang_x,
|
1262 |
+
attention_mask=attention_mask,
|
1263 |
+
labels=None,
|
1264 |
+
image_nums=image_nums,
|
1265 |
+
image_start_index_list=image_start_index_list,
|
1266 |
+
added_bbox_list=visual_box_list,
|
1267 |
+
add_box=visual_box_list is not None,
|
1268 |
+
relations=None,
|
1269 |
+
debug_mode=False,
|
1270 |
+
)
|
1271 |
+
boxes = outputs["boxes"]
|
1272 |
+
scores = outputs["scores"]
|
1273 |
+
if debug:
|
1274 |
+
import pdb; pdb.set_trace()
|
1275 |
+
if return_all:
|
1276 |
+
return boxes, scores
|
1277 |
+
if len(scores) == 0:
|
1278 |
+
return None, None
|
1279 |
+
else:
|
1280 |
+
return boxes[scores.argmax()], scores.max()
|
1281 |
+
|
1282 |
+
|
1283 |
+
def evaluate_aro(
|
1284 |
+
model,
|
1285 |
+
tokenizer,
|
1286 |
+
image_processor,
|
1287 |
+
batch_size,
|
1288 |
+
tsvfile,
|
1289 |
+
max_generation_length=20,
|
1290 |
+
num_beams=3,
|
1291 |
+
length_penalty=-2.0,
|
1292 |
+
device=-1,
|
1293 |
+
vis_embed_size=None,
|
1294 |
+
rank=0,
|
1295 |
+
world_size=1,
|
1296 |
+
id=0,
|
1297 |
+
add_visual=True,
|
1298 |
+
add_relation=False,
|
1299 |
+
subset=False,
|
1300 |
+
choose_left_right=True,
|
1301 |
+
):
|
1302 |
+
both_failed_ids = json.load(open("both_failed_ids.json"))
|
1303 |
+
os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
|
1304 |
+
# from groundingdino.demo.caption_grounder import caption_grounder
|
1305 |
+
# generator = caption_grounder(
|
1306 |
+
# config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
1307 |
+
# checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
1308 |
+
# cpu_only=False,
|
1309 |
+
# box_threshold=0.1, text_threshold=0.1,
|
1310 |
+
# )
|
1311 |
+
dataset_name = "aro"
|
1312 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
1313 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
1314 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
1315 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
1316 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1317 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
1318 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
1319 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
1320 |
+
model.eval().cuda()
|
1321 |
+
total = 0
|
1322 |
+
correct = 0
|
1323 |
+
from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
|
1324 |
+
vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
|
1325 |
+
with open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/unilm/kosmos-2/labels.json") as f:
|
1326 |
+
all_labels = json.load(f)
|
1327 |
+
label_ids = tokenizer(all_labels).input_ids
|
1328 |
+
label_ids = sorted(list(set([x[0] for x in label_ids])))
|
1329 |
+
|
1330 |
+
if subset:
|
1331 |
+
subset_idx = json.load(open("aro_subset.json"))
|
1332 |
+
pbar = tqdm(subset_idx, disable=(rank != 0))
|
1333 |
+
else:
|
1334 |
+
pbar = tqdm(vgr_dataset, disable=(rank != 0))
|
1335 |
+
for ii, sample in enumerate(pbar):
|
1336 |
+
if subset:
|
1337 |
+
ORI_IDX = int(sample)
|
1338 |
+
sample = vgr_dataset[sample]
|
1339 |
+
# if ORI_IDX != 19036:
|
1340 |
+
# continue
|
1341 |
+
if ii % world_size != rank:
|
1342 |
+
continue
|
1343 |
+
|
1344 |
+
# not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0])
|
1345 |
+
# if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right):
|
1346 |
+
# if rank == 0:
|
1347 |
+
# tqdm.write(f"SKIP: {sample['caption_options'][1]}")
|
1348 |
+
# continue
|
1349 |
+
total += 1
|
1350 |
+
# image = sample["image_options"][0]
|
1351 |
+
image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/man_on_hydrant.png").convert("RGB")
|
1352 |
+
image = image.resize((224, 224))
|
1353 |
+
|
1354 |
+
# text = sample["caption_options"][1] # 1 is true caption
|
1355 |
+
text = "the man is sitting on the fire hydrant"
|
1356 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1357 |
+
text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
|
1358 |
+
|
1359 |
+
|
1360 |
+
first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
|
1361 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
|
1362 |
+
first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
|
1363 |
+
|
1364 |
+
|
1365 |
+
# use grounding DINO to get the first bbox
|
1366 |
+
# caption = f"{obj_A}"
|
1367 |
+
# with torch.no_grad():
|
1368 |
+
# logits, boxes = generator.ground_caption_raw(image_pil=image, caption=caption)
|
1369 |
+
# boxes_filt, pred_phrases = generator.postprocess(logits, boxes, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
|
1370 |
+
# objects = {}
|
1371 |
+
# for box, phrase in zip(boxes_filt, pred_phrases):
|
1372 |
+
# obj, score = phrase
|
1373 |
+
# obj = obj[0]
|
1374 |
+
# if obj not in objects:
|
1375 |
+
# objects[obj] = (score, box)
|
1376 |
+
# if objects[obj][0] < score:
|
1377 |
+
# objects[obj] = (score, box)
|
1378 |
+
# try:
|
1379 |
+
# first_box = objects[obj_A][1].clone()
|
1380 |
+
# first_box[:2] -= first_box[2:] / 2
|
1381 |
+
# first_box[2:] += first_box[:2]
|
1382 |
+
# first_box = first_box.clamp(0, 0.99) * 224.0
|
1383 |
+
# first_box = first_box.numpy()
|
1384 |
+
# first_score = objects[obj_A][0]
|
1385 |
+
# except:
|
1386 |
+
# first_box = None
|
1387 |
+
|
1388 |
+
if first_box is None:
|
1389 |
+
text_A = "the " + obj_A
|
1390 |
+
added_bbox_list = None
|
1391 |
+
else:
|
1392 |
+
text_A = visual_obj_A
|
1393 |
+
added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
|
1394 |
+
|
1395 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
|
1396 |
+
pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
|
1397 |
+
prebox_token_id, return_all=True)
|
1398 |
+
|
1399 |
+
|
1400 |
+
# open_cv_image = np.array(image)
|
1401 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
1402 |
+
# for box, score in zip(pre_box, pre_score):
|
1403 |
+
# print(box, score)
|
1404 |
+
# if score > 0.1:
|
1405 |
+
# open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (0, 255, 0), 2)
|
1406 |
+
# cv2.imwrite(f"test1.jpg", open_cv_image)
|
1407 |
+
# print(sample["caption_options"][idx])
|
1408 |
+
# exit()
|
1409 |
+
|
1410 |
+
|
1411 |
+
|
1412 |
+
if pre_boxes is None:
|
1413 |
+
pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
|
1414 |
+
pre_scores = [1.0]
|
1415 |
+
|
1416 |
+
rank_list = []
|
1417 |
+
# pre_boxes = [pre_boxes[0]]
|
1418 |
+
# pre_scores = [pre_scores[0]]
|
1419 |
+
for pre_box, pre_score in zip(pre_boxes, pre_scores):
|
1420 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
|
1421 |
+
|
1422 |
+
encodings = tokenizer(
|
1423 |
+
prompt,
|
1424 |
+
padding="longest",
|
1425 |
+
truncation=True,
|
1426 |
+
return_tensors="pt",
|
1427 |
+
max_length=512,
|
1428 |
+
)
|
1429 |
+
input_ids = encodings["input_ids"]
|
1430 |
+
attention_mask = encodings["attention_mask"]
|
1431 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1432 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1433 |
+
image_nums = [1] * len(input_ids)
|
1434 |
+
vision_x = batch_images.cuda()
|
1435 |
+
lang_x = input_ids.cuda()
|
1436 |
+
attention_mask = attention_mask.cuda()
|
1437 |
+
labels = lang_x.clone()
|
1438 |
+
|
1439 |
+
answer_start_idx = (labels == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1] + 1
|
1440 |
+
# pre_box = None
|
1441 |
+
labels[0, :answer_start_idx] = -100
|
1442 |
+
# # labels[labels == endofobject_token_id] = -100
|
1443 |
+
# labels[:, 0] = -100
|
1444 |
+
# labels[labels == visual_token_id] = -100
|
1445 |
+
# labels[labels == box_token_id] = -100
|
1446 |
+
# labels[labels == previsual_token_id] = -100
|
1447 |
+
# labels[labels == prebox_token_id] = -100
|
1448 |
+
# labels[labels == endofattr_token_id] = -100
|
1449 |
+
# labels[labels == tokenizer.pad_token_id] = -100
|
1450 |
+
# labels[labels == media_token_id] = -100
|
1451 |
+
# labels[labels == endofmedia_token_id] = -100
|
1452 |
+
answer_ids = tokenizer(f" {obj_B}", add_special_tokens=False)["input_ids"]
|
1453 |
+
labels[input_ids == visual_token_id] = -100
|
1454 |
+
labels[input_ids == box_token_id] = -100
|
1455 |
+
labels[input_ids == endofattr_token_id] = -100
|
1456 |
+
labels[input_ids == previsual_token_id] = -100
|
1457 |
+
labels[input_ids == prebox_token_id] = -100
|
1458 |
+
labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
|
1459 |
+
labels[torch.roll(input_ids == box_token_id, 1)] = -100
|
1460 |
+
labels[:, 0] = -100
|
1461 |
+
labels[input_ids == tokenizer.pad_token_id] = -100
|
1462 |
+
labels[input_ids == media_token_id] = -100
|
1463 |
+
labels[input_ids == endofmedia_token_id] = -100
|
1464 |
+
|
1465 |
+
added_bbox_list = None
|
1466 |
+
if add_visual:
|
1467 |
+
added_bbox_list = []
|
1468 |
+
if first_box is not None:
|
1469 |
+
added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
|
1470 |
+
if pre_box is not None:
|
1471 |
+
added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
|
1472 |
+
if added_bbox_list is not None and len(added_bbox_list) == 0:
|
1473 |
+
added_bbox_list = None
|
1474 |
+
|
1475 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
1476 |
+
outputs = model(
|
1477 |
+
vision_x=vision_x,
|
1478 |
+
lang_x=lang_x,
|
1479 |
+
attention_mask=attention_mask,
|
1480 |
+
labels=labels,
|
1481 |
+
image_nums=image_nums,
|
1482 |
+
image_start_index_list=image_start_index_list,
|
1483 |
+
added_bbox_list=added_bbox_list,
|
1484 |
+
add_box=added_bbox_list is not None,
|
1485 |
+
relations=None,
|
1486 |
+
)
|
1487 |
+
logits = outputs["logits"][0, answer_start_idx:]
|
1488 |
+
# _rank = logits[0][label_ids].sort(descending=True).indices.tolist().index(label_ids.index(answer_ids[0]))
|
1489 |
+
_rank = logits[0].sort(descending=True).indices.tolist().index(answer_ids[0])
|
1490 |
+
print(tokenizer.decode(logits[0].sort(descending=True).indices.tolist()[:10]))
|
1491 |
+
print(tokenizer.decode(logits[1].sort(descending=True).indices.tolist()[:10]))
|
1492 |
+
rank_list.append(_rank)
|
1493 |
+
# open_cv_image = np.array(image)
|
1494 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
1495 |
+
# if first_box is not None:
|
1496 |
+
# open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
|
1497 |
+
# if pre_box is not None:
|
1498 |
+
# open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
|
1499 |
+
|
1500 |
+
# font = cv2.FONT_HERSHEY_SIMPLEX
|
1501 |
+
# org = [10, 20]
|
1502 |
+
# fontScale = 0.5
|
1503 |
+
# color = (0, 0, 0)
|
1504 |
+
# thickness = 1
|
1505 |
+
# open_cv_image = cv2.resize(open_cv_image, (512, 512))
|
1506 |
+
# put_text = sample["caption_options"][1]
|
1507 |
+
# open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
|
1508 |
+
# org[1] += 20
|
1509 |
+
# put_text = "top10 in green box"
|
1510 |
+
# open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
|
1511 |
+
# fontScale = 1.0
|
1512 |
+
# thickness = 2
|
1513 |
+
# for ind in logits_list[i][0].sort(descending=True).indices[:10]:
|
1514 |
+
# org[1] += 20
|
1515 |
+
# put_text = f"{tokenizer.decode(ind)}"
|
1516 |
+
# open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
|
1517 |
+
# tqdm.write(f"{tokenizer.decode(logits_list[i][0].sort(descending=True).indices[:10])}")
|
1518 |
+
# tqdm.write(f"{rank_list}")
|
1519 |
+
final_rank = min(rank_list)
|
1520 |
+
if final_rank < 10:
|
1521 |
+
correct += 1
|
1522 |
+
TYPE = "CORRECT"
|
1523 |
+
# if ii in both_failed_ids:
|
1524 |
+
# tqdm.write(f"case find->{sample['caption_options'][1]}")
|
1525 |
+
# image.save(f"case_study/{ii}_{rank_list}_{sample['caption_options'][1]}.jpg")
|
1526 |
+
if rank == 0:
|
1527 |
+
tqdm.write(f"correct: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
|
1528 |
+
else:
|
1529 |
+
TYPE = "WRONG"
|
1530 |
+
if rank == 0:
|
1531 |
+
tqdm.write(f"wrong: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
|
1532 |
+
# cv2.imwrite(f"visualization/aro_results_{id}/{TYPE}_{ORI_IDX}.jpg", open_cv_image)
|
1533 |
+
pbar.set_description(f"score: {correct / total:.4f} | {final_rank}")
|
1534 |
+
|
1535 |
+
|
1536 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1537 |
+
f.write(json.dumps([total, correct]))
|
1538 |
+
if world_size > 1:
|
1539 |
+
torch.distributed.barrier()
|
1540 |
+
if rank == 0:
|
1541 |
+
total = 0
|
1542 |
+
correct = 0
|
1543 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1544 |
+
for rank_i in range(world_size):
|
1545 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1546 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1547 |
+
total += total_part
|
1548 |
+
correct += correct_part
|
1549 |
+
score = correct / total
|
1550 |
+
print("score:", score, "total:", total)
|
1551 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
1552 |
+
pass
|
1553 |
+
else:
|
1554 |
+
score = 0.0
|
1555 |
+
if world_size > 1:
|
1556 |
+
torch.distributed.barrier()
|
1557 |
+
return score
|
1558 |
+
|
1559 |
+
|
1560 |
+
def evaluate_pisc(
|
1561 |
+
model,
|
1562 |
+
tokenizer,
|
1563 |
+
image_processor,
|
1564 |
+
batch_size,
|
1565 |
+
tsvfile,
|
1566 |
+
max_generation_length=20,
|
1567 |
+
num_beams=3,
|
1568 |
+
length_penalty=-2.0,
|
1569 |
+
device=-1,
|
1570 |
+
vis_embed_size=None,
|
1571 |
+
rank=0,
|
1572 |
+
world_size=1,
|
1573 |
+
id=0,
|
1574 |
+
add_visual=True,
|
1575 |
+
):
|
1576 |
+
from open_flamingo.train.instruction_template import PISC_TEMPLATES
|
1577 |
+
dataset_name = "pisc"
|
1578 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
1579 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
1580 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
1581 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
1582 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1583 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
1584 |
+
model.train().cuda()
|
1585 |
+
|
1586 |
+
dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd")
|
1587 |
+
pbar = tqdm(dataset, disable=(rank != 0))
|
1588 |
+
|
1589 |
+
rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
|
1590 |
+
rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)}
|
1591 |
+
gt = []
|
1592 |
+
pred_scores = []
|
1593 |
+
for III, sample in enumerate(pbar):
|
1594 |
+
if III % world_size != rank:
|
1595 |
+
continue
|
1596 |
+
image_path, dataset, data = sample
|
1597 |
+
image = Image.open(image_path)
|
1598 |
+
size = image_processor.transforms[0].size
|
1599 |
+
image = image.resize((size, size))
|
1600 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1601 |
+
boxA = data[0]
|
1602 |
+
boxB = data[1]
|
1603 |
+
gt_relation = data[2]
|
1604 |
+
losses = []
|
1605 |
+
for i_rel, option_rel in enumerate(rel_id_to_type):
|
1606 |
+
text = PISC_TEMPLATES[0].format(relation=option_rel)
|
1607 |
+
added_bbox = [
|
1608 |
+
torch.tensor([boxA]).cuda(),
|
1609 |
+
torch.tensor([boxB]).cuda(),
|
1610 |
+
]
|
1611 |
+
caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}"
|
1612 |
+
encodings = tokenizer(
|
1613 |
+
caption,
|
1614 |
+
padding="longest",
|
1615 |
+
truncation=True,
|
1616 |
+
return_tensors="pt",
|
1617 |
+
max_length=2000,
|
1618 |
+
)
|
1619 |
+
input_ids = encodings["input_ids"]
|
1620 |
+
attention_mask = encodings["attention_mask"]
|
1621 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1622 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1623 |
+
image_nums = [1] * len(input_ids)
|
1624 |
+
vision_x = batch_images.cuda()
|
1625 |
+
lang_x = input_ids.cuda()
|
1626 |
+
attention_mask = attention_mask.cuda()
|
1627 |
+
|
1628 |
+
labels = lang_x.clone()
|
1629 |
+
labels[labels == tokenizer.pad_token_id] = -100
|
1630 |
+
if add_visual:
|
1631 |
+
# endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True))
|
1632 |
+
# endofattr_next_token_index[1] += 1
|
1633 |
+
# endofattr_next_token_id = labels[endofattr_next_token_index]
|
1634 |
+
# </obj><visual><box></attr>NEXT_WORD
|
1635 |
+
# </obj> predict NEXT_WORD
|
1636 |
+
# <visual><box></attr> predict nothing
|
1637 |
+
labels[labels == visual_token_id] = -100
|
1638 |
+
labels[labels == box_token_id] = -100
|
1639 |
+
labels[labels == endofattr_token_id] = -100
|
1640 |
+
# labels[endofattr_next_token_index] = -100
|
1641 |
+
labels[:, 0] = -100
|
1642 |
+
answer_token_id = tokenizer(" Answer").input_ids[0]
|
1643 |
+
answer_token_loc = (input_ids == answer_token_id).nonzero()
|
1644 |
+
for batch_idx, idx in answer_token_loc:
|
1645 |
+
labels[batch_idx][:idx+2] = -100
|
1646 |
+
|
1647 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
1648 |
+
outputs = model(
|
1649 |
+
vision_x=vision_x,
|
1650 |
+
lang_x=lang_x,
|
1651 |
+
attention_mask=attention_mask,
|
1652 |
+
labels=labels,
|
1653 |
+
image_nums=image_nums,
|
1654 |
+
image_start_index_list=image_start_index_list,
|
1655 |
+
added_bbox_list=added_bbox,
|
1656 |
+
add_box=added_bbox is not None,
|
1657 |
+
)
|
1658 |
+
loss_total = outputs.loss.reshape(labels.shape[0], -1)
|
1659 |
+
loss = loss_total.sum() / (loss_total != 0).sum()
|
1660 |
+
losses.append(loss.item())
|
1661 |
+
pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum())
|
1662 |
+
gt.append(rel_type_to_id[gt_relation])
|
1663 |
+
gt = np.array(gt)
|
1664 |
+
pred_scores = np.array(pred_scores)
|
1665 |
+
pred = pred_scores.argmax(1)
|
1666 |
+
|
1667 |
+
|
1668 |
+
print("total num:", len(gt))
|
1669 |
+
recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
|
1670 |
+
print("recalls:", recalls)
|
1671 |
+
|
1672 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1673 |
+
f.write(json.dumps([gt.tolist(), pred.tolist()]))
|
1674 |
+
if world_size > 1:
|
1675 |
+
torch.distributed.barrier()
|
1676 |
+
if rank == 0:
|
1677 |
+
gt = []
|
1678 |
+
pred = []
|
1679 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1680 |
+
for rank_i in range(world_size):
|
1681 |
+
[gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1682 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1683 |
+
gt.extend(gt_part)
|
1684 |
+
pred.extend(pred_part)
|
1685 |
+
print("total num:", len(gt))
|
1686 |
+
recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
|
1687 |
+
print("recalls:", recalls)
|
1688 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f:
|
1689 |
+
f.write(f"{gt}\n")
|
1690 |
+
f.write(f"{pred}\n")
|
1691 |
+
f.write(f"{recalls}\n")
|
1692 |
+
score = 0.0
|
1693 |
+
if world_size > 1:
|
1694 |
+
torch.distributed.barrier()
|
1695 |
+
return score
|
1696 |
+
|
1697 |
+
|
1698 |
+
|
1699 |
+
if __name__ == "__main__":
|
1700 |
+
main()
|
multimodal/build/lib/open_flamingo/eval/evaluate_temp.py
ADDED
@@ -0,0 +1,1838 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from math import ceil
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import uuid
|
7 |
+
from collections import defaultdict
|
8 |
+
from typing import Callable
|
9 |
+
import time
|
10 |
+
import cv2
|
11 |
+
import webdataset as wds
|
12 |
+
from sklearn.metrics import recall_score, average_precision_score
|
13 |
+
|
14 |
+
import more_itertools
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from coco_metric import compute_cider, postprocess_captioning_generation
|
18 |
+
from eval_datasets import VQADataset, GQADataset
|
19 |
+
from tqdm import tqdm
|
20 |
+
from collections import Counter
|
21 |
+
|
22 |
+
from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
|
23 |
+
from open_flamingo.eval.classification import (
|
24 |
+
compute_per_sample_probs,
|
25 |
+
compute_per_sample_loss,
|
26 |
+
)
|
27 |
+
from open_flamingo.eval.imagenet_utils import (
|
28 |
+
openai_imagenet_classnames,
|
29 |
+
IMAGENET_1K_CLASS_ID_TO_LABEL,
|
30 |
+
)
|
31 |
+
|
32 |
+
from open_flamingo.src.factory import create_model_and_transforms
|
33 |
+
from PIL import Image
|
34 |
+
from io import BytesIO
|
35 |
+
import base64
|
36 |
+
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
|
37 |
+
import string
|
38 |
+
from lavis.datasets.builders import load_dataset
|
39 |
+
|
40 |
+
|
41 |
+
def get_iou(box1, box2):
|
42 |
+
# box1 and box2 should be in the format [x1, y1, x2, y2]
|
43 |
+
intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
|
44 |
+
max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
|
45 |
+
area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
46 |
+
area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
47 |
+
union = area_box1 + area_box2 - intersection
|
48 |
+
iou = intersection / union if union > 0 else 0
|
49 |
+
return iou
|
50 |
+
|
51 |
+
def expand2square(pil_img, background_color):
|
52 |
+
width, height = pil_img.size
|
53 |
+
if width == height:
|
54 |
+
return pil_img
|
55 |
+
elif width > height:
|
56 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
57 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
58 |
+
return result
|
59 |
+
else:
|
60 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
61 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
62 |
+
return result
|
63 |
+
|
64 |
+
parser = argparse.ArgumentParser()
|
65 |
+
parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
|
66 |
+
parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
|
67 |
+
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
|
68 |
+
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
|
69 |
+
parser.add_argument("--checkpoint_path", type=str, required=True)
|
70 |
+
parser.add_argument(
|
71 |
+
"--results_file", type=str, default=None, help="JSON file to save results"
|
72 |
+
)
|
73 |
+
|
74 |
+
# Trial arguments
|
75 |
+
parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
|
76 |
+
parser.add_argument(
|
77 |
+
"--num_trials",
|
78 |
+
type=int,
|
79 |
+
default=1,
|
80 |
+
help="Number of trials to run for each shot using different demonstrations",
|
81 |
+
)
|
82 |
+
parser.add_argument(
|
83 |
+
"--trial_seeds",
|
84 |
+
nargs="+",
|
85 |
+
default=[0],
|
86 |
+
help="Seeds to use for each trial for picking demonstrations and eval sets",
|
87 |
+
)
|
88 |
+
parser.add_argument(
|
89 |
+
"--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
|
90 |
+
)
|
91 |
+
|
92 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
93 |
+
|
94 |
+
# Per-dataset evaluation flags
|
95 |
+
parser.add_argument(
|
96 |
+
"--eval_coco",
|
97 |
+
action="store_true",
|
98 |
+
default=False,
|
99 |
+
help="Whether to evaluate on COCO.",
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--eval_vqav2",
|
103 |
+
action="store_true",
|
104 |
+
default=False,
|
105 |
+
help="Whether to evaluate on VQAV2.",
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--eval_ok_vqa",
|
109 |
+
action="store_true",
|
110 |
+
default=False,
|
111 |
+
help="Whether to evaluate on OK-VQA.",
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"--eval_imagenet",
|
115 |
+
action="store_true",
|
116 |
+
default=False,
|
117 |
+
help="Whether to evaluate on ImageNet.",
|
118 |
+
)
|
119 |
+
|
120 |
+
parser.add_argument(
|
121 |
+
"--eval_flickr30",
|
122 |
+
action="store_true",
|
123 |
+
default=False,
|
124 |
+
help="Whether to evaluate on Flickr30.",
|
125 |
+
)
|
126 |
+
|
127 |
+
parser.add_argument(
|
128 |
+
"--eval_refcoco",
|
129 |
+
action="store_true",
|
130 |
+
default=False,
|
131 |
+
help="Whether to evaluate on RefCOCO.",
|
132 |
+
)
|
133 |
+
|
134 |
+
# Dataset arguments
|
135 |
+
|
136 |
+
## Flickr30 Dataset
|
137 |
+
parser.add_argument(
|
138 |
+
"--flickr_image_dir_path",
|
139 |
+
type=str,
|
140 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
141 |
+
default=None,
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"--flickr_annotations_json_path",
|
145 |
+
type=str,
|
146 |
+
help="Path to the dataset_flickr30k_coco_style.json file.",
|
147 |
+
default=None,
|
148 |
+
)
|
149 |
+
|
150 |
+
## COCO Dataset
|
151 |
+
parser.add_argument(
|
152 |
+
"--coco_image_dir_path",
|
153 |
+
type=str,
|
154 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
155 |
+
default=None,
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--coco_annotations_json_path",
|
159 |
+
type=str,
|
160 |
+
default=None,
|
161 |
+
)
|
162 |
+
|
163 |
+
## VQAV2 Dataset
|
164 |
+
parser.add_argument(
|
165 |
+
"--vqav2_image_dir_path",
|
166 |
+
type=str,
|
167 |
+
default=None,
|
168 |
+
)
|
169 |
+
parser.add_argument(
|
170 |
+
"--vqav2_questions_json_path",
|
171 |
+
type=str,
|
172 |
+
default=None,
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--vqav2_annotations_json_path",
|
176 |
+
type=str,
|
177 |
+
default=None,
|
178 |
+
)
|
179 |
+
|
180 |
+
## OK-VQA Dataset
|
181 |
+
parser.add_argument(
|
182 |
+
"--ok_vqa_image_dir_path",
|
183 |
+
type=str,
|
184 |
+
help="Path to the vqav2/train2014 directory.",
|
185 |
+
default=None,
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--ok_vqa_questions_json_path",
|
189 |
+
type=str,
|
190 |
+
help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
|
191 |
+
default=None,
|
192 |
+
)
|
193 |
+
parser.add_argument(
|
194 |
+
"--ok_vqa_annotations_json_path",
|
195 |
+
type=str,
|
196 |
+
help="Path to the v2_mscoco_train2014_annotations.json file.",
|
197 |
+
default=None,
|
198 |
+
)
|
199 |
+
|
200 |
+
## Imagenet dataset
|
201 |
+
parser.add_argument("--imagenet_root", type=str, default="/tmp")
|
202 |
+
|
203 |
+
## RefCOCO dataset
|
204 |
+
parser.add_argument("--refcoco_tsvfile", type=str, default=None)
|
205 |
+
|
206 |
+
parser.add_argument(
|
207 |
+
"--location_token_num",
|
208 |
+
default=1000,
|
209 |
+
type=int,
|
210 |
+
)
|
211 |
+
# distributed training
|
212 |
+
parser.add_argument(
|
213 |
+
"--dist-url",
|
214 |
+
default="env://",
|
215 |
+
type=str,
|
216 |
+
help="url used to set up distributed training",
|
217 |
+
)
|
218 |
+
parser.add_argument(
|
219 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
220 |
+
)
|
221 |
+
parser.add_argument(
|
222 |
+
"--horovod",
|
223 |
+
default=False,
|
224 |
+
action="store_true",
|
225 |
+
help="Use horovod for distributed training.",
|
226 |
+
)
|
227 |
+
parser.add_argument(
|
228 |
+
"--no-set-device-rank",
|
229 |
+
default=False,
|
230 |
+
action="store_true",
|
231 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
232 |
+
)
|
233 |
+
parser.add_argument(
|
234 |
+
"--dist",
|
235 |
+
default=False,
|
236 |
+
action="store_true",
|
237 |
+
)
|
238 |
+
parser.add_argument(
|
239 |
+
"--lora",
|
240 |
+
default=False,
|
241 |
+
action="store_true",
|
242 |
+
)
|
243 |
+
parser.add_argument(
|
244 |
+
"--lora_r",
|
245 |
+
default=16,
|
246 |
+
type=int,
|
247 |
+
required=False,
|
248 |
+
)
|
249 |
+
parser.add_argument(
|
250 |
+
"--legacy",
|
251 |
+
default=False,
|
252 |
+
action="store_true",
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--special",
|
256 |
+
default=False,
|
257 |
+
action="store_true",
|
258 |
+
)
|
259 |
+
parser.add_argument(
|
260 |
+
"--id",
|
261 |
+
default=0,
|
262 |
+
type=int,
|
263 |
+
required=False,
|
264 |
+
)
|
265 |
+
|
266 |
+
parser.add_argument(
|
267 |
+
"--eval_gqa",
|
268 |
+
default=False,
|
269 |
+
action="store_true",
|
270 |
+
)
|
271 |
+
parser.add_argument(
|
272 |
+
"--use_sam",
|
273 |
+
default=None,
|
274 |
+
type=str,
|
275 |
+
required=False,
|
276 |
+
)
|
277 |
+
parser.add_argument(
|
278 |
+
"--add_visual_token",
|
279 |
+
default=False,
|
280 |
+
action="store_true",
|
281 |
+
)
|
282 |
+
parser.add_argument(
|
283 |
+
"--use_format_v2",
|
284 |
+
default=False,
|
285 |
+
action="store_true",
|
286 |
+
)
|
287 |
+
parser.add_argument(
|
288 |
+
"--eval_aro",
|
289 |
+
default=False,
|
290 |
+
action="store_true",
|
291 |
+
)
|
292 |
+
parser.add_argument(
|
293 |
+
"--eval_pisc",
|
294 |
+
default=False,
|
295 |
+
action="store_true",
|
296 |
+
)
|
297 |
+
|
298 |
+
|
299 |
+
class OKVQAPostProcess():
|
300 |
+
def __init__(self):
|
301 |
+
self._lemmatizer = None
|
302 |
+
|
303 |
+
def _lemmatize(self, answers):
|
304 |
+
def apply(answer):
|
305 |
+
doc = self.lemmatizer(answer)
|
306 |
+
|
307 |
+
words = []
|
308 |
+
for token in doc:
|
309 |
+
if token.pos_ in ["NOUN", "VERB"]:
|
310 |
+
words.append(token.lemma_)
|
311 |
+
else:
|
312 |
+
words.append(token.text)
|
313 |
+
answer = " ".join(words)
|
314 |
+
|
315 |
+
return answer
|
316 |
+
|
317 |
+
return [apply(answer) for answer in answers]
|
318 |
+
|
319 |
+
@property
|
320 |
+
def lemmatizer(self):
|
321 |
+
if self._lemmatizer is None:
|
322 |
+
try:
|
323 |
+
import spacy
|
324 |
+
|
325 |
+
self._lemmatizer = spacy.load("en_core_web_sm")
|
326 |
+
except ImportError:
|
327 |
+
logging.error(
|
328 |
+
"""
|
329 |
+
Please install spacy and en_core_web_sm model to apply lemmatization.
|
330 |
+
python -m spacy download en_core_web_sm
|
331 |
+
OR
|
332 |
+
import spacy.cli
|
333 |
+
spacy.cli.download("en_core_web_sm")
|
334 |
+
"""
|
335 |
+
)
|
336 |
+
exit(1)
|
337 |
+
|
338 |
+
return self._lemmatizer
|
339 |
+
|
340 |
+
|
341 |
+
def main():
|
342 |
+
args = parser.parse_args()
|
343 |
+
if args.dist:
|
344 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
345 |
+
print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
|
346 |
+
device_id = init_distributed_device(args)
|
347 |
+
else:
|
348 |
+
args.rank = 0
|
349 |
+
args.world_size = 1
|
350 |
+
print(f"rank: {args.rank} world_size: {args.world_size}")
|
351 |
+
|
352 |
+
if "sam" in args.checkpoint_path:
|
353 |
+
args.use_sam = "vit_l"
|
354 |
+
|
355 |
+
args.add_visual_token = True
|
356 |
+
if "lora" in args.checkpoint_path:
|
357 |
+
args.lora = True
|
358 |
+
|
359 |
+
|
360 |
+
args.add_pe = False
|
361 |
+
args.add_box = True
|
362 |
+
args.relation = False
|
363 |
+
args.enhance_data = False
|
364 |
+
args.use_format_v2 = True
|
365 |
+
|
366 |
+
|
367 |
+
|
368 |
+
import hashlib
|
369 |
+
args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest()
|
370 |
+
|
371 |
+
# load model
|
372 |
+
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
|
373 |
+
args.vision_encoder_path,
|
374 |
+
args.vision_encoder_pretrained,
|
375 |
+
args.lm_path,
|
376 |
+
args.lm_tokenizer_path,
|
377 |
+
location_token_num=args.location_token_num,
|
378 |
+
lora=args.lora,
|
379 |
+
lora_r=16,
|
380 |
+
use_sam=args.use_sam,
|
381 |
+
add_visual_token=args.add_visual_token,
|
382 |
+
use_format_v2=args.use_format_v2,
|
383 |
+
add_box=args.add_box,
|
384 |
+
add_pe=args.add_pe,
|
385 |
+
add_relation=args.relation,
|
386 |
+
enhance_data=args.enhance_data,
|
387 |
+
)
|
388 |
+
flamingo.use_format_v2 = args.use_format_v2
|
389 |
+
if args.special:
|
390 |
+
flamingo.special = True
|
391 |
+
else:
|
392 |
+
flamingo.special = False
|
393 |
+
if args.legacy:
|
394 |
+
flamingo.legacy = True
|
395 |
+
print("use legacy evaluation")
|
396 |
+
flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
|
397 |
+
flamingo.expr_name = args.checkpoint_path.split("/")[-2]
|
398 |
+
if args.rank == 0:
|
399 |
+
print("legacy", True if hasattr(flamingo, "legacy") else False)
|
400 |
+
print("step:", flamingo.step_num)
|
401 |
+
print("expr:", flamingo.expr_name)
|
402 |
+
print("use format v2:", flamingo.use_format_v2)
|
403 |
+
print(args)
|
404 |
+
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
405 |
+
model_state_dict = {}
|
406 |
+
for key in checkpoint["model_state_dict"].keys():
|
407 |
+
model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
|
408 |
+
if "vision_encoder.logit_scale"in model_state_dict:
|
409 |
+
# previous checkpoint has some unnecessary weights
|
410 |
+
del model_state_dict["vision_encoder.logit_scale"]
|
411 |
+
del model_state_dict["vision_encoder.visual.proj"]
|
412 |
+
del model_state_dict["vision_encoder.visual.ln_post.weight"]
|
413 |
+
del model_state_dict["vision_encoder.visual.ln_post.bias"]
|
414 |
+
flamingo.load_state_dict(model_state_dict, strict=True)
|
415 |
+
results = defaultdict(list)
|
416 |
+
if args.eval_coco:
|
417 |
+
print("Evaluating on COCO...")
|
418 |
+
for shot in args.shots:
|
419 |
+
scores = []
|
420 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
421 |
+
cider_score = evaluate_coco_flickr(
|
422 |
+
model=flamingo,
|
423 |
+
tokenizer=tokenizer,
|
424 |
+
image_processor=image_processor,
|
425 |
+
batch_size=args.batch_size,
|
426 |
+
image_dir_path=args.coco_image_dir_path,
|
427 |
+
annotations_json_path=args.coco_annotations_json_path,
|
428 |
+
device=args.device,
|
429 |
+
seed=seed,
|
430 |
+
vis_embed_size=vis_embed_size,
|
431 |
+
rank=args.rank,
|
432 |
+
world_size=args.world_size,
|
433 |
+
id=args.id,
|
434 |
+
)
|
435 |
+
print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
|
436 |
+
scores.append(cider_score)
|
437 |
+
print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
|
438 |
+
results["coco"].append(
|
439 |
+
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
|
440 |
+
)
|
441 |
+
|
442 |
+
if args.eval_ok_vqa:
|
443 |
+
print("Evaluating on OK-VQA...")
|
444 |
+
for shot in args.shots:
|
445 |
+
scores = []
|
446 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
447 |
+
ok_vqa_score = evaluate_vqa(
|
448 |
+
model=flamingo,
|
449 |
+
tokenizer=tokenizer,
|
450 |
+
image_processor=image_processor,
|
451 |
+
batch_size=args.batch_size,
|
452 |
+
image_dir_path=args.ok_vqa_image_dir_path,
|
453 |
+
questions_json_path=args.ok_vqa_questions_json_path,
|
454 |
+
annotations_json_path=args.ok_vqa_annotations_json_path,
|
455 |
+
vqa_dataset="ok_vqa",
|
456 |
+
vis_embed_size=vis_embed_size,
|
457 |
+
rank=args.rank,
|
458 |
+
world_size=args.world_size,
|
459 |
+
id=args.id,
|
460 |
+
)
|
461 |
+
results["ok_vqa"].append(
|
462 |
+
{"shots": shot, "score": ok_vqa_score}
|
463 |
+
)
|
464 |
+
|
465 |
+
if args.eval_vqav2:
|
466 |
+
print("Evaluating on VQAv2...")
|
467 |
+
for shot in args.shots:
|
468 |
+
scores = []
|
469 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
470 |
+
vqa_score = evaluate_vqa(
|
471 |
+
model=flamingo,
|
472 |
+
tokenizer=tokenizer,
|
473 |
+
image_processor=image_processor,
|
474 |
+
batch_size=args.batch_size,
|
475 |
+
image_dir_path=args.vqav2_image_dir_path,
|
476 |
+
questions_json_path=args.vqav2_questions_json_path,
|
477 |
+
annotations_json_path=args.vqav2_annotations_json_path,
|
478 |
+
vqa_dataset="vqa",
|
479 |
+
vis_embed_size=vis_embed_size,
|
480 |
+
rank=args.rank,
|
481 |
+
world_size=args.world_size,
|
482 |
+
id=args.id,
|
483 |
+
)
|
484 |
+
results["vqav2"].append(
|
485 |
+
{"shots": shot, "score": vqa_score}
|
486 |
+
)
|
487 |
+
|
488 |
+
if args.eval_gqa:
|
489 |
+
print("Evaluating on GQA...")
|
490 |
+
for shot in args.shots:
|
491 |
+
scores = []
|
492 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
493 |
+
vqa_score = evaluate_vqa(
|
494 |
+
model=flamingo,
|
495 |
+
tokenizer=tokenizer,
|
496 |
+
image_processor=image_processor,
|
497 |
+
batch_size=args.batch_size,
|
498 |
+
vqa_dataset="gqa",
|
499 |
+
vis_embed_size=vis_embed_size,
|
500 |
+
rank=args.rank,
|
501 |
+
world_size=args.world_size,
|
502 |
+
id=args.id,
|
503 |
+
)
|
504 |
+
results["gqa"].append(
|
505 |
+
{"shots": shot, "score": vqa_score}
|
506 |
+
)
|
507 |
+
|
508 |
+
if args.eval_imagenet:
|
509 |
+
print("Evaluating on ImageNet...")
|
510 |
+
for shot in args.shots:
|
511 |
+
scores = []
|
512 |
+
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
|
513 |
+
imagenet_score = evaluate_imagenet(
|
514 |
+
model=flamingo,
|
515 |
+
tokenizer=tokenizer,
|
516 |
+
image_processor=image_processor,
|
517 |
+
batch_size=args.batch_size,
|
518 |
+
num_samples=args.num_samples,
|
519 |
+
num_shots=shot,
|
520 |
+
device=args.device,
|
521 |
+
seed=seed,
|
522 |
+
imagenet_root=args.imagenet_root,
|
523 |
+
)
|
524 |
+
print(
|
525 |
+
f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
|
526 |
+
)
|
527 |
+
scores.append(imagenet_score)
|
528 |
+
print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
|
529 |
+
results["imagenet"].append(
|
530 |
+
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
|
531 |
+
)
|
532 |
+
|
533 |
+
if args.eval_refcoco:
|
534 |
+
print("Evaluating on RefCOCO...")
|
535 |
+
refcoco_score = evaluate_refcoco(
|
536 |
+
model=flamingo,
|
537 |
+
tokenizer=tokenizer,
|
538 |
+
image_processor=image_processor,
|
539 |
+
batch_size=args.batch_size,
|
540 |
+
device=args.device,
|
541 |
+
tsvfile=args.refcoco_tsvfile,
|
542 |
+
vis_embed_size=vis_embed_size,
|
543 |
+
rank=args.rank,
|
544 |
+
world_size=args.world_size,
|
545 |
+
id=args.id,
|
546 |
+
)
|
547 |
+
results["refcoco"].append(
|
548 |
+
{"score": refcoco_score}
|
549 |
+
)
|
550 |
+
if args.eval_aro:
|
551 |
+
print("Evaluating on ARO...")
|
552 |
+
_func = evaluate_aro
|
553 |
+
# print("Evaluating on ARO ORI...")
|
554 |
+
# _func = evaluate_aro_ori
|
555 |
+
aro_score = _func(
|
556 |
+
model=flamingo,
|
557 |
+
tokenizer=tokenizer,
|
558 |
+
image_processor=image_processor,
|
559 |
+
batch_size=args.batch_size,
|
560 |
+
device=args.device,
|
561 |
+
tsvfile=args.refcoco_tsvfile,
|
562 |
+
vis_embed_size=vis_embed_size,
|
563 |
+
rank=args.rank,
|
564 |
+
world_size=args.world_size,
|
565 |
+
id=args.id,
|
566 |
+
add_relation=args.relation,
|
567 |
+
)
|
568 |
+
results["aro"].append(
|
569 |
+
{"score": aro_score}
|
570 |
+
)
|
571 |
+
if args.eval_pisc:
|
572 |
+
print("Evaluating on ARO...")
|
573 |
+
aro_score = evaluate_pisc(
|
574 |
+
model=flamingo,
|
575 |
+
tokenizer=tokenizer,
|
576 |
+
image_processor=image_processor,
|
577 |
+
batch_size=args.batch_size,
|
578 |
+
device=args.device,
|
579 |
+
tsvfile=args.refcoco_tsvfile,
|
580 |
+
vis_embed_size=vis_embed_size,
|
581 |
+
rank=args.rank,
|
582 |
+
world_size=args.world_size,
|
583 |
+
id=args.id,
|
584 |
+
)
|
585 |
+
results["pisc"].append(
|
586 |
+
{"score": aro_score}
|
587 |
+
)
|
588 |
+
|
589 |
+
def prepare_batch_images(batch, image_processor):
|
590 |
+
batch_images = None
|
591 |
+
for b in batch:
|
592 |
+
b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
593 |
+
if batch_images is None:
|
594 |
+
batch_images = b_image
|
595 |
+
else:
|
596 |
+
batch_images = torch.cat([batch_images, b_image], dim=0)
|
597 |
+
return batch_images
|
598 |
+
|
599 |
+
def get_outputs(
|
600 |
+
model,
|
601 |
+
batch_images,
|
602 |
+
attention_mask,
|
603 |
+
max_generation_length,
|
604 |
+
min_generation_length,
|
605 |
+
num_beams,
|
606 |
+
length_penalty,
|
607 |
+
input_ids,
|
608 |
+
image_start_index_list=None,
|
609 |
+
image_nums=None,
|
610 |
+
bad_words_ids=None,
|
611 |
+
):
|
612 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
613 |
+
outputs = model.generate(
|
614 |
+
batch_images,
|
615 |
+
input_ids,
|
616 |
+
attention_mask=attention_mask,
|
617 |
+
max_new_tokens=max_generation_length,
|
618 |
+
min_length=min_generation_length,
|
619 |
+
num_beams=num_beams,
|
620 |
+
length_penalty=length_penalty,
|
621 |
+
image_start_index_list=image_start_index_list,
|
622 |
+
image_nums=image_nums,
|
623 |
+
bad_words_ids=bad_words_ids,
|
624 |
+
)
|
625 |
+
|
626 |
+
outputs = outputs[:, len(input_ids[0]) :]
|
627 |
+
return outputs
|
628 |
+
|
629 |
+
|
630 |
+
def evaluate_coco_flickr(
|
631 |
+
model,
|
632 |
+
tokenizer,
|
633 |
+
image_processor,
|
634 |
+
batch_size,
|
635 |
+
image_dir_path,
|
636 |
+
annotations_json_path,
|
637 |
+
seed=42,
|
638 |
+
max_generation_length=20,
|
639 |
+
num_beams=1,
|
640 |
+
length_penalty=-2.0,
|
641 |
+
device=-1,
|
642 |
+
is_flickr=False,
|
643 |
+
vis_embed_size=None,
|
644 |
+
rank=0,
|
645 |
+
world_size=1,
|
646 |
+
id=0,
|
647 |
+
):
|
648 |
+
"""Evaluate a model on COCO dataset.
|
649 |
+
|
650 |
+
Args:
|
651 |
+
model (nn.Module): model to evaluate
|
652 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
653 |
+
image_processor : image processor for the model
|
654 |
+
batch_size (int): batch size
|
655 |
+
image_dir_path (str, optional): path to the directory containing the images.
|
656 |
+
annotations_json_path (str, optional): path to the json file containing the annotations.
|
657 |
+
seed (int, optional): seed for random number generator. Defaults to 42.
|
658 |
+
max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
|
659 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
660 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
661 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
|
662 |
+
query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
|
663 |
+
num_shots (int, optional): number of in-context samples to use. Defaults to 8.
|
664 |
+
device (int, optional): device to use. Defaults to -1.
|
665 |
+
num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
|
666 |
+
is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
|
667 |
+
|
668 |
+
Returns:
|
669 |
+
float: CIDEr score
|
670 |
+
|
671 |
+
"""
|
672 |
+
# eval_dataset = COCOFlickrDataset(
|
673 |
+
# image_dir_path=image_dir_path,
|
674 |
+
# annotations_path=annotations_json_path,
|
675 |
+
# is_flickr=is_flickr,
|
676 |
+
# )
|
677 |
+
coco_dataset = load_dataset("coco_caption")
|
678 |
+
eval_dataset = coco_dataset["test"]
|
679 |
+
|
680 |
+
|
681 |
+
model.eval().cuda()
|
682 |
+
predictions = defaultdict()
|
683 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
684 |
+
# if "peft" in lang_encoder_name:
|
685 |
+
# lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
686 |
+
try:
|
687 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
688 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
689 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
690 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
691 |
+
except:
|
692 |
+
pass
|
693 |
+
|
694 |
+
def get_prompt(sample):
|
695 |
+
return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
696 |
+
|
697 |
+
tokenizer.padding_side = "left"
|
698 |
+
cnt = 0
|
699 |
+
if world_size > 1:
|
700 |
+
torch.distributed.barrier()
|
701 |
+
desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
|
702 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
703 |
+
tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
|
704 |
+
)):
|
705 |
+
if ii % world_size != rank:
|
706 |
+
continue
|
707 |
+
cnt += len(batch)
|
708 |
+
batch_images = prepare_batch_images(
|
709 |
+
batch=batch,
|
710 |
+
image_processor=image_processor,
|
711 |
+
).cuda()
|
712 |
+
batch_text = [get_prompt(s) for s in batch]
|
713 |
+
encodings = tokenizer(
|
714 |
+
batch_text,
|
715 |
+
padding="longest",
|
716 |
+
truncation=True,
|
717 |
+
return_tensors="pt",
|
718 |
+
max_length=2000,
|
719 |
+
)
|
720 |
+
input_ids = encodings["input_ids"].cuda()
|
721 |
+
attention_mask = encodings["attention_mask"].cuda()
|
722 |
+
skip_special_tokens = False
|
723 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
724 |
+
if rank == 0:
|
725 |
+
tqdm.write("use legacy model")
|
726 |
+
skip_special_tokens = True
|
727 |
+
for i in range(len(input_ids)):
|
728 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
729 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
730 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
731 |
+
input_ids[i, media_token_index] = pad_token_id
|
732 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
733 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
734 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
735 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
736 |
+
image_nums = [1] * len(input_ids)
|
737 |
+
if "llama" in lang_encoder_name:
|
738 |
+
attention_mask[input_ids == 0] = 0
|
739 |
+
outputs = get_outputs(
|
740 |
+
model=model,
|
741 |
+
batch_images=batch_images,
|
742 |
+
attention_mask=attention_mask,
|
743 |
+
max_generation_length=30,
|
744 |
+
min_generation_length=8,
|
745 |
+
num_beams=5,
|
746 |
+
length_penalty=0,
|
747 |
+
input_ids=input_ids,
|
748 |
+
image_start_index_list=image_start_index_list,
|
749 |
+
image_nums=image_nums,
|
750 |
+
)
|
751 |
+
new_predictions = [
|
752 |
+
postprocess_captioning_generation(out).replace('"', "")
|
753 |
+
for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
754 |
+
]
|
755 |
+
# if rank == 0:
|
756 |
+
# tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
|
757 |
+
|
758 |
+
for i, sample in enumerate(batch):
|
759 |
+
predictions[int(sample["image_id"])] = {
|
760 |
+
"caption": new_predictions[i],
|
761 |
+
}
|
762 |
+
results_path = (
|
763 |
+
f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
|
764 |
+
if is_flickr
|
765 |
+
else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
|
766 |
+
)
|
767 |
+
with open(results_path, "w") as f:
|
768 |
+
f.write(
|
769 |
+
json.dumps(
|
770 |
+
[
|
771 |
+
{"image_id": k, "caption": predictions[k]["caption"]}
|
772 |
+
for k in predictions
|
773 |
+
],
|
774 |
+
indent=2,
|
775 |
+
)
|
776 |
+
)
|
777 |
+
print("save to", results_path)
|
778 |
+
del predictions
|
779 |
+
time.sleep(10)
|
780 |
+
if world_size > 1:
|
781 |
+
torch.distributed.barrier()
|
782 |
+
if rank == 0:
|
783 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
784 |
+
predictions = []
|
785 |
+
for rank_i in range(world_size):
|
786 |
+
part_results_path = (
|
787 |
+
f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
788 |
+
if is_flickr
|
789 |
+
else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
790 |
+
)
|
791 |
+
print("load", part_results_path)
|
792 |
+
predictions.extend(json.load(open(part_results_path)))
|
793 |
+
os.remove(part_results_path)
|
794 |
+
print("num:", len(predictions))
|
795 |
+
results_path = (
|
796 |
+
f"flickrresults_{lang_encoder_name}.json"
|
797 |
+
if is_flickr
|
798 |
+
else f"cocoresults_{lang_encoder_name}.json"
|
799 |
+
)
|
800 |
+
json.dump(predictions, open(results_path, "w"), indent=2)
|
801 |
+
|
802 |
+
metrics = compute_cider(
|
803 |
+
result_path=results_path,
|
804 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
|
805 |
+
)
|
806 |
+
os.makedirs("eval_results", exist_ok=True)
|
807 |
+
acc = metrics["CIDEr"]
|
808 |
+
with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
809 |
+
f.write(json.dumps(predictions, indent=2))
|
810 |
+
|
811 |
+
# delete the temporary file
|
812 |
+
os.remove(results_path)
|
813 |
+
else:
|
814 |
+
metrics = {}
|
815 |
+
metrics["CIDEr"] = 0.0
|
816 |
+
|
817 |
+
return metrics["CIDEr"]
|
818 |
+
|
819 |
+
|
820 |
+
def evaluate_vqa(
|
821 |
+
model,
|
822 |
+
tokenizer,
|
823 |
+
image_processor,
|
824 |
+
batch_size,
|
825 |
+
image_dir_path=None,
|
826 |
+
questions_json_path=None,
|
827 |
+
annotations_json_path=None,
|
828 |
+
vqa_dataset="vqa",
|
829 |
+
vis_embed_size=None,
|
830 |
+
rank=0,
|
831 |
+
world_size=1,
|
832 |
+
id=0,
|
833 |
+
):
|
834 |
+
"""
|
835 |
+
Evaluate a model on VQA datasets. Currently supports VQA v2.0.
|
836 |
+
|
837 |
+
Args:
|
838 |
+
model (nn.Module): model to evaluate
|
839 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
840 |
+
image_processor : image processor for the model
|
841 |
+
batch_size (int): batch size
|
842 |
+
image_dir_path (str): path to image directory
|
843 |
+
questions_json_path (str): path to questions json file
|
844 |
+
annotations_json_path (str): path to annotations json file
|
845 |
+
seed (int, optional): random seed. Defaults to 42.
|
846 |
+
max_generation_length (int, optional): max generation length. Defaults to 5.
|
847 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
848 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
849 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
|
850 |
+
query_set_size (int, optional): size of the query set. Defaults to 2048.
|
851 |
+
num_shots (int, optional): number of shots to use. Defaults to 8.
|
852 |
+
device (int, optional): device to use. Defaults to -1 (cpu).
|
853 |
+
num_workers (int, optional): number of workers to use. Defaults to 4.
|
854 |
+
vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
|
855 |
+
Returns:
|
856 |
+
float: accuracy score
|
857 |
+
"""
|
858 |
+
if world_size > 1:
|
859 |
+
torch.distributed.barrier()
|
860 |
+
if vqa_dataset == "gqa":
|
861 |
+
eval_dataset = GQADataset()
|
862 |
+
else:
|
863 |
+
eval_dataset = VQADataset(
|
864 |
+
image_dir_path=image_dir_path,
|
865 |
+
question_path=questions_json_path,
|
866 |
+
annotations_path=annotations_json_path,
|
867 |
+
vqa_dataset=vqa_dataset,
|
868 |
+
)
|
869 |
+
postprocessor = OKVQAPostProcess()
|
870 |
+
try:
|
871 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
872 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
873 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
874 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
875 |
+
except:
|
876 |
+
pass
|
877 |
+
def get_prompt(sample):
|
878 |
+
return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
|
879 |
+
# return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
880 |
+
|
881 |
+
model.eval().cuda()
|
882 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
883 |
+
if "peft" in lang_encoder_name:
|
884 |
+
lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
|
885 |
+
predictions = []
|
886 |
+
tokenizer.padding_side = "left"
|
887 |
+
if world_size > 1:
|
888 |
+
torch.distributed.barrier()
|
889 |
+
this_tot = 0
|
890 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
891 |
+
tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
|
892 |
+
)):
|
893 |
+
if ii % world_size != rank:
|
894 |
+
continue
|
895 |
+
batch_images = prepare_batch_images(
|
896 |
+
batch=batch,
|
897 |
+
image_processor=image_processor,
|
898 |
+
).cuda()
|
899 |
+
batch_text = [get_prompt(s) for s in batch]
|
900 |
+
encodings = tokenizer(
|
901 |
+
batch_text,
|
902 |
+
return_tensors="pt",
|
903 |
+
padding="longest",
|
904 |
+
truncation=True,
|
905 |
+
max_length=2000,
|
906 |
+
)
|
907 |
+
input_ids = encodings["input_ids"].cuda()
|
908 |
+
attention_mask = encodings["attention_mask"].cuda()
|
909 |
+
skip_special_tokens = True
|
910 |
+
if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
|
911 |
+
if rank == 0:
|
912 |
+
tqdm.write("use legacy model")
|
913 |
+
for i in range(len(input_ids)):
|
914 |
+
media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
|
915 |
+
endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
|
916 |
+
input_ids[i, media_token_index - 1] = media_token_id
|
917 |
+
input_ids[i, media_token_index] = pad_token_id
|
918 |
+
input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
|
919 |
+
input_ids[i, endofmedia_token_index] = bos_token_id
|
920 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
921 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
922 |
+
image_nums = [1] * len(input_ids)
|
923 |
+
if "llama" in lang_encoder_name:
|
924 |
+
attention_mask[input_ids == 0] = 0
|
925 |
+
outputs = get_outputs(
|
926 |
+
model=model,
|
927 |
+
batch_images=batch_images,
|
928 |
+
attention_mask=attention_mask,
|
929 |
+
max_generation_length=10,
|
930 |
+
min_generation_length=1,
|
931 |
+
num_beams=5,
|
932 |
+
length_penalty=0,
|
933 |
+
input_ids=input_ids,
|
934 |
+
image_start_index_list=image_start_index_list,
|
935 |
+
image_nums=image_nums,
|
936 |
+
)
|
937 |
+
# postprocess begin
|
938 |
+
new_predictions = [
|
939 |
+
out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
|
940 |
+
]
|
941 |
+
if vqa_dataset == "ok_vqa":
|
942 |
+
new_predictions = postprocessor._lemmatize(new_predictions)
|
943 |
+
if model.special:
|
944 |
+
for i in range(len(new_predictions)):
|
945 |
+
for answer, _ in Counter(batch[i]['answers']).most_common():
|
946 |
+
if answer in new_predictions[i]:
|
947 |
+
new_predictions[i] = answer
|
948 |
+
break
|
949 |
+
if "cant" in new_predictions[i] and "no" == answer:
|
950 |
+
new_predictions[i] = answer
|
951 |
+
break
|
952 |
+
if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
|
953 |
+
new_predictions[i] = answer
|
954 |
+
break
|
955 |
+
|
956 |
+
this_tot += 1
|
957 |
+
if rank == 0 and this_tot % 20 == 0:
|
958 |
+
for i in range(1):
|
959 |
+
tqdm.write(f"question: {batch[i]['question']}\nanswer: {batch[i]['answers']}model output: " + new_predictions[i])
|
960 |
+
|
961 |
+
predictions.extend(
|
962 |
+
[
|
963 |
+
{"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
|
964 |
+
for p, sample in zip(new_predictions, batch)
|
965 |
+
]
|
966 |
+
)
|
967 |
+
with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
|
968 |
+
f.write(json.dumps(predictions))
|
969 |
+
print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
|
970 |
+
|
971 |
+
time.sleep(10)
|
972 |
+
if world_size > 1:
|
973 |
+
torch.distributed.barrier()
|
974 |
+
if rank == 0:
|
975 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
976 |
+
predictions = []
|
977 |
+
for rank_i in range(world_size):
|
978 |
+
print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
979 |
+
predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
|
980 |
+
os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
981 |
+
print("num:", len(predictions))
|
982 |
+
# save the predictions to a temporary file
|
983 |
+
random_uuid = str(uuid.uuid4())
|
984 |
+
with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
|
985 |
+
f.write(json.dumps(predictions, indent=4))
|
986 |
+
|
987 |
+
if vqa_dataset == "gqa":
|
988 |
+
acc = compute_gqa_accuracy(predictions)
|
989 |
+
else:
|
990 |
+
acc = compute_vqa_accuracy(
|
991 |
+
f"{vqa_dataset}results_{random_uuid}.json",
|
992 |
+
questions_json_path,
|
993 |
+
annotations_json_path,
|
994 |
+
vqa_dataset=vqa_dataset,
|
995 |
+
)
|
996 |
+
print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
|
997 |
+
os.makedirs("eval_results", exist_ok=True)
|
998 |
+
with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
999 |
+
f.write(json.dumps(predictions, indent=2))
|
1000 |
+
|
1001 |
+
# delete the temporary file
|
1002 |
+
os.remove(f"{vqa_dataset}results_{random_uuid}.json")
|
1003 |
+
else:
|
1004 |
+
time.sleep(5)
|
1005 |
+
acc = 0.0
|
1006 |
+
if world_size > 1:
|
1007 |
+
torch.distributed.barrier()
|
1008 |
+
return acc
|
1009 |
+
|
1010 |
+
|
1011 |
+
def evaluate_refcoco(
|
1012 |
+
model,
|
1013 |
+
tokenizer,
|
1014 |
+
image_processor,
|
1015 |
+
batch_size,
|
1016 |
+
tsvfile,
|
1017 |
+
max_generation_length=20,
|
1018 |
+
num_beams=3,
|
1019 |
+
length_penalty=-2.0,
|
1020 |
+
device=-1,
|
1021 |
+
vis_embed_size=None,
|
1022 |
+
rank=0,
|
1023 |
+
world_size=1,
|
1024 |
+
id=0,
|
1025 |
+
):
|
1026 |
+
model.eval().cuda()
|
1027 |
+
loc_token_ids = []
|
1028 |
+
for i in range(1000):
|
1029 |
+
loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
|
1030 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
1031 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1032 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
1033 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
1034 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
1035 |
+
# all_ids = set(range(model.lang_encoder.lm_head.out_features))
|
1036 |
+
# bad_words_ids = list(all_ids - set(loc_token_ids))
|
1037 |
+
# bad_words_ids = [[b] for b in bad_words_ids]
|
1038 |
+
# min_loc_token_id = min(loc_token_ids)
|
1039 |
+
# max_loc_token_id = max(loc_token_ids)
|
1040 |
+
total = 0
|
1041 |
+
correct = 0
|
1042 |
+
ious = []
|
1043 |
+
if "refcocog" in tsvfile:
|
1044 |
+
dataset_name = "refcocog"
|
1045 |
+
elif "refcocoplus" in tsvfile:
|
1046 |
+
dataset_name = "refcocoplus"
|
1047 |
+
else:
|
1048 |
+
dataset_name = "refcoco"
|
1049 |
+
with open(tsvfile, "r") as f:
|
1050 |
+
lines = f.readlines()
|
1051 |
+
pbar = tqdm(lines, disable=(rank != 0))
|
1052 |
+
for ii, line in enumerate(pbar):
|
1053 |
+
if ii % world_size != rank:
|
1054 |
+
continue
|
1055 |
+
total += 1
|
1056 |
+
line = line.rstrip()
|
1057 |
+
uniq_id, image_id, text, region_coord, image = line.split("\t")
|
1058 |
+
|
1059 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
|
1060 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
|
1061 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
|
1062 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
|
1063 |
+
|
1064 |
+
gt_box = np.array(list(map(float, region_coord.split(","))))
|
1065 |
+
width = image.width
|
1066 |
+
height = image.height
|
1067 |
+
image = image.resize((224, 224))
|
1068 |
+
gt_box = gt_box / np.array([width, height, width, height]) * 224
|
1069 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1070 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
|
1071 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
|
1072 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
|
1073 |
+
# prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
|
1074 |
+
|
1075 |
+
|
1076 |
+
encodings = tokenizer(
|
1077 |
+
prompt,
|
1078 |
+
padding="longest",
|
1079 |
+
truncation=True,
|
1080 |
+
return_tensors="pt",
|
1081 |
+
max_length=2000,
|
1082 |
+
)
|
1083 |
+
input_ids = encodings["input_ids"]
|
1084 |
+
attention_mask = encodings["attention_mask"]
|
1085 |
+
# attention_mask[input_ids == prebox_token_id] = 0
|
1086 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1087 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1088 |
+
image_nums = [1] * len(input_ids)
|
1089 |
+
vision_x = batch_images.cuda()
|
1090 |
+
lang_x = input_ids.cuda()
|
1091 |
+
attention_mask = attention_mask.cuda()
|
1092 |
+
|
1093 |
+
model.debug_id = 0
|
1094 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
1095 |
+
outputs = model(
|
1096 |
+
vision_x=vision_x,
|
1097 |
+
lang_x=lang_x,
|
1098 |
+
attention_mask=attention_mask,
|
1099 |
+
labels=None,
|
1100 |
+
image_nums=image_nums,
|
1101 |
+
image_start_index_list=image_start_index_list,
|
1102 |
+
added_bbox_list=None,
|
1103 |
+
add_box=False,
|
1104 |
+
)
|
1105 |
+
boxes = outputs["boxes"]
|
1106 |
+
scores = outputs["scores"]
|
1107 |
+
if len(scores) > 0:
|
1108 |
+
box = boxes[scores.argmax()]
|
1109 |
+
iou = get_iou(box, gt_box)
|
1110 |
+
else:
|
1111 |
+
iou = 0.0
|
1112 |
+
# tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
|
1113 |
+
tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}")
|
1114 |
+
if iou >= 0.5:
|
1115 |
+
correct += 1
|
1116 |
+
pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
|
1117 |
+
# open_cv_image = np.array(image)
|
1118 |
+
# # Convert RGB to BGR
|
1119 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
1120 |
+
# for box, score in zip(boxes, scores):
|
1121 |
+
# open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
|
1122 |
+
# cv2.imwrite("output.jpg", open_cv_image)
|
1123 |
+
# print(boxes)
|
1124 |
+
# print(scores)
|
1125 |
+
# exit()
|
1126 |
+
|
1127 |
+
|
1128 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1129 |
+
f.write(json.dumps([total, correct]))
|
1130 |
+
if world_size > 1:
|
1131 |
+
torch.distributed.barrier()
|
1132 |
+
if rank == 0:
|
1133 |
+
total = 0
|
1134 |
+
correct = 0
|
1135 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1136 |
+
for rank_i in range(world_size):
|
1137 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1138 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1139 |
+
total += total_part
|
1140 |
+
correct += correct_part
|
1141 |
+
score = correct / total
|
1142 |
+
print("score:", score)
|
1143 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
1144 |
+
pass
|
1145 |
+
else:
|
1146 |
+
score = 0.0
|
1147 |
+
if world_size > 1:
|
1148 |
+
torch.distributed.barrier()
|
1149 |
+
return score
|
1150 |
+
|
1151 |
+
|
1152 |
+
def preprocess_visual_info(Text):
|
1153 |
+
text = Text.split(" ")
|
1154 |
+
for is_idx, t in enumerate(text):
|
1155 |
+
if t == "is":
|
1156 |
+
break
|
1157 |
+
the_idx = is_idx
|
1158 |
+
while text[the_idx] != "the":
|
1159 |
+
the_idx -= 1
|
1160 |
+
obj_A = " ".join(text[the_idx+1:is_idx])
|
1161 |
+
second_the_idx = len(text) - 1
|
1162 |
+
while text[second_the_idx] != "the":
|
1163 |
+
second_the_idx -= 1
|
1164 |
+
obj_B = " ".join(text[second_the_idx+1:])
|
1165 |
+
relation = " ".join(text[is_idx+1:second_the_idx])
|
1166 |
+
visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
|
1167 |
+
visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>"
|
1168 |
+
Text = f"{visual_obj_A} is {relation} {visual_obj_B}"
|
1169 |
+
return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation
|
1170 |
+
|
1171 |
+
|
1172 |
+
|
1173 |
+
def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox, debug=False, return_all=False):
|
1174 |
+
assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
|
1175 |
+
encodings = tokenizer(
|
1176 |
+
prompt,
|
1177 |
+
padding="longest",
|
1178 |
+
truncation=True,
|
1179 |
+
return_tensors="pt",
|
1180 |
+
max_length=2000,
|
1181 |
+
)
|
1182 |
+
input_ids = encodings["input_ids"]
|
1183 |
+
attention_mask = encodings["attention_mask"]
|
1184 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1185 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1186 |
+
image_nums = [1] * len(input_ids)
|
1187 |
+
vision_x = batch_images.cuda()
|
1188 |
+
lang_x = input_ids.cuda()
|
1189 |
+
attention_mask = attention_mask.cuda()
|
1190 |
+
prebox_mask = (input_ids == prebox_token_id)
|
1191 |
+
if mask_prebox and prebox_mask.any():
|
1192 |
+
attention_mask[prebox_mask] = 0
|
1193 |
+
|
1194 |
+
model.debug_id = 0
|
1195 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
1196 |
+
outputs = model(
|
1197 |
+
vision_x=vision_x,
|
1198 |
+
lang_x=lang_x,
|
1199 |
+
attention_mask=attention_mask,
|
1200 |
+
labels=None,
|
1201 |
+
image_nums=image_nums,
|
1202 |
+
image_start_index_list=image_start_index_list,
|
1203 |
+
added_bbox_list=visual_box_list,
|
1204 |
+
add_box=visual_box_list is not None,
|
1205 |
+
relations=None,
|
1206 |
+
debug_mode=False,
|
1207 |
+
)
|
1208 |
+
boxes = outputs["boxes"]
|
1209 |
+
scores = outputs["scores"]
|
1210 |
+
if debug:
|
1211 |
+
import pdb; pdb.set_trace()
|
1212 |
+
if return_all:
|
1213 |
+
return boxes, scores
|
1214 |
+
if len(scores) == 0:
|
1215 |
+
return None, None
|
1216 |
+
else:
|
1217 |
+
return boxes[scores.argmax()], scores.max()
|
1218 |
+
|
1219 |
+
|
1220 |
+
def evaluate_aro(
|
1221 |
+
model,
|
1222 |
+
tokenizer,
|
1223 |
+
image_processor,
|
1224 |
+
batch_size,
|
1225 |
+
tsvfile,
|
1226 |
+
max_generation_length=20,
|
1227 |
+
num_beams=3,
|
1228 |
+
length_penalty=-2.0,
|
1229 |
+
device=-1,
|
1230 |
+
vis_embed_size=None,
|
1231 |
+
rank=0,
|
1232 |
+
world_size=1,
|
1233 |
+
id=0,
|
1234 |
+
add_visual=True,
|
1235 |
+
add_relation=False,
|
1236 |
+
subset=True,
|
1237 |
+
choose_left_right=True,
|
1238 |
+
):
|
1239 |
+
os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
|
1240 |
+
from groundingdino.demo.caption_grounder import caption_grounder
|
1241 |
+
generator = caption_grounder(
|
1242 |
+
config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
1243 |
+
checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
1244 |
+
cpu_only=False,
|
1245 |
+
box_threshold=0.1, text_threshold=0.1,
|
1246 |
+
)
|
1247 |
+
dataset_name = "aro"
|
1248 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
1249 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
1250 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
1251 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
1252 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1253 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
1254 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
1255 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
1256 |
+
model.eval().cuda()
|
1257 |
+
total = 0
|
1258 |
+
correct = 0
|
1259 |
+
from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
|
1260 |
+
vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
|
1261 |
+
with open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/unilm/kosmos-2/labels.json") as f:
|
1262 |
+
all_labels = json.load(f)
|
1263 |
+
label_ids = tokenizer(all_labels).input_ids
|
1264 |
+
label_ids = sorted(list(set([x[0] for x in label_ids])))
|
1265 |
+
|
1266 |
+
if subset:
|
1267 |
+
subset_idx = json.load(open("aro_subset.json"))
|
1268 |
+
pbar = tqdm(subset_idx, disable=(rank != 0))
|
1269 |
+
else:
|
1270 |
+
pbar = tqdm(vgr_dataset, disable=(rank != 0))
|
1271 |
+
|
1272 |
+
|
1273 |
+
exist_total = 0
|
1274 |
+
for ii, sample in enumerate(pbar):
|
1275 |
+
if subset:
|
1276 |
+
ORI_IDX = int(sample)
|
1277 |
+
sample = vgr_dataset[sample]
|
1278 |
+
# if ORI_IDX != 19036:
|
1279 |
+
# continue
|
1280 |
+
if ii % world_size != rank:
|
1281 |
+
continue
|
1282 |
+
|
1283 |
+
not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0])
|
1284 |
+
if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right):
|
1285 |
+
if rank == 0:
|
1286 |
+
tqdm.write(f"SKIP: {sample['caption_options'][1]}")
|
1287 |
+
continue
|
1288 |
+
total += 1
|
1289 |
+
image = sample["image_options"][0]
|
1290 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
|
1291 |
+
image = image.resize((224, 224))
|
1292 |
+
|
1293 |
+
chosen_idx = 0
|
1294 |
+
text = sample["caption_options"][chosen_idx] # 1 is true caption
|
1295 |
+
# text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog"
|
1296 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1297 |
+
text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
|
1298 |
+
|
1299 |
+
|
1300 |
+
first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
|
1301 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
|
1302 |
+
first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox=True, return_all=False)
|
1303 |
+
|
1304 |
+
|
1305 |
+
# use grounding DINO to get the first bbox
|
1306 |
+
# caption = f"{obj_A}"
|
1307 |
+
# with torch.no_grad():
|
1308 |
+
# logits, boxes = generator.ground_caption_raw(image_pil=image, caption=caption)
|
1309 |
+
# boxes_filt, pred_phrases = generator.postprocess(logits, boxes, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
|
1310 |
+
# objects = {}
|
1311 |
+
# for box, phrase in zip(boxes_filt, pred_phrases):
|
1312 |
+
# obj, score = phrase
|
1313 |
+
# obj = obj[0]
|
1314 |
+
# if obj not in objects:
|
1315 |
+
# objects[obj] = (score, box)
|
1316 |
+
# if objects[obj][0] < score:
|
1317 |
+
# objects[obj] = (score, box)
|
1318 |
+
# try:
|
1319 |
+
# first_box = objects[obj_A][1].clone()
|
1320 |
+
# first_box[:2] -= first_box[2:] / 2
|
1321 |
+
# first_box[2:] += first_box[:2]
|
1322 |
+
# first_box = first_box.clamp(0, 0.99) * 224.0
|
1323 |
+
# first_box = first_box.numpy()
|
1324 |
+
# first_score = objects[obj_A][0]
|
1325 |
+
# except:
|
1326 |
+
# first_box = None
|
1327 |
+
|
1328 |
+
if first_box is None:
|
1329 |
+
text_A = "the " + obj_A
|
1330 |
+
added_bbox_list = None
|
1331 |
+
else:
|
1332 |
+
text_A = visual_obj_A
|
1333 |
+
added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
|
1334 |
+
|
1335 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
|
1336 |
+
pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
|
1337 |
+
prebox_token_id, mask_prebox=False, debug=False, return_all=True)
|
1338 |
+
|
1339 |
+
|
1340 |
+
open_cv_image = np.array(image)
|
1341 |
+
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
1342 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
1343 |
+
fontScale = 0.5
|
1344 |
+
color = (0, 0, 0)
|
1345 |
+
thickness = 1
|
1346 |
+
if first_box is not None:
|
1347 |
+
open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
|
1348 |
+
exist_flag = False
|
1349 |
+
for box, score in zip(pre_boxes, pre_scores):
|
1350 |
+
if score >= 0.5:
|
1351 |
+
exist_flag = True
|
1352 |
+
open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (0, 255, 0), 2)
|
1353 |
+
org = box[:2].astype(int)
|
1354 |
+
org[1] += 20
|
1355 |
+
org[0] += 10
|
1356 |
+
open_cv_image = cv2.putText(open_cv_image, f"{score:.2f}", org, font, fontScale, (255, 255, 255), thickness, cv2.LINE_AA)
|
1357 |
+
open_cv_image = cv2.resize(open_cv_image, (512, 512))
|
1358 |
+
put_text = sample["caption_options"][chosen_idx]
|
1359 |
+
org = [10, 20]
|
1360 |
+
open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
|
1361 |
+
# cv2.imwrite(f"visualization/aro_results_{id}/{str(ORI_IDX).zfill(8)}.jpg", open_cv_image)
|
1362 |
+
if exist_flag:
|
1363 |
+
exist_total += 1
|
1364 |
+
continue
|
1365 |
+
|
1366 |
+
|
1367 |
+
|
1368 |
+
if pre_boxes is None:
|
1369 |
+
pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
|
1370 |
+
pre_scores = [1.0]
|
1371 |
+
|
1372 |
+
rank_list = []
|
1373 |
+
# pre_boxes = [pre_boxes[0]]
|
1374 |
+
# pre_scores = [pre_scores[0]]
|
1375 |
+
for pre_box, pre_score in zip(pre_boxes, pre_scores):
|
1376 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
|
1377 |
+
|
1378 |
+
encodings = tokenizer(
|
1379 |
+
prompt,
|
1380 |
+
padding="longest",
|
1381 |
+
truncation=True,
|
1382 |
+
return_tensors="pt",
|
1383 |
+
max_length=512,
|
1384 |
+
)
|
1385 |
+
input_ids = encodings["input_ids"]
|
1386 |
+
attention_mask = encodings["attention_mask"]
|
1387 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1388 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1389 |
+
image_nums = [1] * len(input_ids)
|
1390 |
+
vision_x = batch_images.cuda()
|
1391 |
+
lang_x = input_ids.cuda()
|
1392 |
+
attention_mask = attention_mask.cuda()
|
1393 |
+
labels = lang_x.clone()
|
1394 |
+
|
1395 |
+
answer_start_idx = (labels == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1] + 1
|
1396 |
+
# pre_box = None
|
1397 |
+
labels[0, :answer_start_idx] = -100
|
1398 |
+
# # labels[labels == endofobject_token_id] = -100
|
1399 |
+
# labels[:, 0] = -100
|
1400 |
+
# labels[labels == visual_token_id] = -100
|
1401 |
+
# labels[labels == box_token_id] = -100
|
1402 |
+
# labels[labels == previsual_token_id] = -100
|
1403 |
+
# labels[labels == prebox_token_id] = -100
|
1404 |
+
# labels[labels == endofattr_token_id] = -100
|
1405 |
+
# labels[labels == tokenizer.pad_token_id] = -100
|
1406 |
+
# labels[labels == media_token_id] = -100
|
1407 |
+
# labels[labels == endofmedia_token_id] = -100
|
1408 |
+
answer_ids = tokenizer(f" {obj_B}", add_special_tokens=False)["input_ids"]
|
1409 |
+
labels[input_ids == visual_token_id] = -100
|
1410 |
+
labels[input_ids == box_token_id] = -100
|
1411 |
+
labels[input_ids == endofattr_token_id] = -100
|
1412 |
+
labels[input_ids == previsual_token_id] = -100
|
1413 |
+
labels[input_ids == prebox_token_id] = -100
|
1414 |
+
labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
|
1415 |
+
labels[torch.roll(input_ids == box_token_id, 1)] = -100
|
1416 |
+
labels[:, 0] = -100
|
1417 |
+
labels[input_ids == tokenizer.pad_token_id] = -100
|
1418 |
+
labels[input_ids == media_token_id] = -100
|
1419 |
+
labels[input_ids == endofmedia_token_id] = -100
|
1420 |
+
|
1421 |
+
added_bbox_list = None
|
1422 |
+
if add_visual:
|
1423 |
+
added_bbox_list = []
|
1424 |
+
if first_box is not None:
|
1425 |
+
added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
|
1426 |
+
if pre_box is not None:
|
1427 |
+
added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
|
1428 |
+
if added_bbox_list is not None and len(added_bbox_list) == 0:
|
1429 |
+
added_bbox_list = None
|
1430 |
+
|
1431 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
1432 |
+
outputs = model(
|
1433 |
+
vision_x=vision_x,
|
1434 |
+
lang_x=lang_x,
|
1435 |
+
attention_mask=attention_mask,
|
1436 |
+
labels=labels,
|
1437 |
+
image_nums=image_nums,
|
1438 |
+
image_start_index_list=image_start_index_list,
|
1439 |
+
added_bbox_list=added_bbox_list,
|
1440 |
+
add_box=added_bbox_list is not None,
|
1441 |
+
relations=None,
|
1442 |
+
)
|
1443 |
+
logits = outputs["logits"][0, answer_start_idx:]
|
1444 |
+
_rank = logits[0][label_ids].sort(descending=True).indices.tolist().index(label_ids.index(answer_ids[0]))
|
1445 |
+
rank_list.append(_rank)
|
1446 |
+
# open_cv_image = np.array(image)
|
1447 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
1448 |
+
# if first_box is not None:
|
1449 |
+
# open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
|
1450 |
+
# if pre_box is not None:
|
1451 |
+
# open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
|
1452 |
+
|
1453 |
+
# font = cv2.FONT_HERSHEY_SIMPLEX
|
1454 |
+
# org = [10, 20]
|
1455 |
+
# fontScale = 0.5
|
1456 |
+
# color = (0, 0, 0)
|
1457 |
+
# thickness = 1
|
1458 |
+
# open_cv_image = cv2.resize(open_cv_image, (512, 512))
|
1459 |
+
# put_text = sample["caption_options"][1]
|
1460 |
+
# open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
|
1461 |
+
# org[1] += 20
|
1462 |
+
# put_text = "top10 in green box"
|
1463 |
+
# open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
|
1464 |
+
# fontScale = 1.0
|
1465 |
+
# thickness = 2
|
1466 |
+
# for ind in logits_list[i][0].sort(descending=True).indices[:10]:
|
1467 |
+
# org[1] += 20
|
1468 |
+
# put_text = f"{tokenizer.decode(ind)}"
|
1469 |
+
# open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
|
1470 |
+
# tqdm.write(f"{tokenizer.decode(logits_list[i][0].sort(descending=True).indices[:10])}")
|
1471 |
+
# tqdm.write(f"{rank_list}")
|
1472 |
+
final_rank = min(rank_list)
|
1473 |
+
if final_rank < 10:
|
1474 |
+
correct += 1
|
1475 |
+
TYPE = "CORRECT"
|
1476 |
+
if rank == 0:
|
1477 |
+
tqdm.write(f"correct: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
|
1478 |
+
else:
|
1479 |
+
TYPE = "WRONG"
|
1480 |
+
if rank == 0:
|
1481 |
+
tqdm.write(f"wrong: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
|
1482 |
+
# cv2.imwrite(f"visualization/aro_results_{id}/{TYPE}_{ORI_IDX}.jpg", open_cv_image)
|
1483 |
+
pbar.set_description(f"score: {correct / total:.4f} | {final_rank}")
|
1484 |
+
|
1485 |
+
|
1486 |
+
|
1487 |
+
|
1488 |
+
|
1489 |
+
print(exist_total)
|
1490 |
+
exit()
|
1491 |
+
|
1492 |
+
|
1493 |
+
|
1494 |
+
|
1495 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1496 |
+
f.write(json.dumps([total, correct]))
|
1497 |
+
if world_size > 1:
|
1498 |
+
torch.distributed.barrier()
|
1499 |
+
if rank == 0:
|
1500 |
+
total = 0
|
1501 |
+
correct = 0
|
1502 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1503 |
+
for rank_i in range(world_size):
|
1504 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1505 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1506 |
+
total += total_part
|
1507 |
+
correct += correct_part
|
1508 |
+
score = correct / total
|
1509 |
+
print("score:", score, "total:", total)
|
1510 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
1511 |
+
pass
|
1512 |
+
else:
|
1513 |
+
score = 0.0
|
1514 |
+
if world_size > 1:
|
1515 |
+
torch.distributed.barrier()
|
1516 |
+
return score
|
1517 |
+
|
1518 |
+
|
1519 |
+
|
1520 |
+
|
1521 |
+
def evaluate_aro_ori(
|
1522 |
+
model,
|
1523 |
+
tokenizer,
|
1524 |
+
image_processor,
|
1525 |
+
batch_size,
|
1526 |
+
tsvfile,
|
1527 |
+
max_generation_length=20,
|
1528 |
+
num_beams=3,
|
1529 |
+
length_penalty=-2.0,
|
1530 |
+
device=-1,
|
1531 |
+
vis_embed_size=None,
|
1532 |
+
rank=0,
|
1533 |
+
world_size=1,
|
1534 |
+
id=0,
|
1535 |
+
add_visual=True,
|
1536 |
+
add_relation=False,
|
1537 |
+
subset=True,
|
1538 |
+
choose_left_right=True,
|
1539 |
+
only_highest=True,
|
1540 |
+
):
|
1541 |
+
os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
|
1542 |
+
dataset_name = "aroori"
|
1543 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
1544 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
1545 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
1546 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
1547 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1548 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
1549 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
1550 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
1551 |
+
model.eval().cuda()
|
1552 |
+
total = 0
|
1553 |
+
correct = 0
|
1554 |
+
from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
|
1555 |
+
vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
|
1556 |
+
if subset:
|
1557 |
+
subset_idx = json.load(open("aro_subset.json"))
|
1558 |
+
pbar = tqdm(subset_idx, disable=(rank != 0))
|
1559 |
+
else:
|
1560 |
+
pbar = tqdm(vgr_dataset, disable=(rank != 0))
|
1561 |
+
for ii, sample in enumerate(pbar):
|
1562 |
+
if subset:
|
1563 |
+
ORI_IDX = int(sample)
|
1564 |
+
sample = vgr_dataset[sample]
|
1565 |
+
# if ORI_IDX != 19036:
|
1566 |
+
# continue
|
1567 |
+
if ii % world_size != rank:
|
1568 |
+
continue
|
1569 |
+
|
1570 |
+
not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0])
|
1571 |
+
if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right):
|
1572 |
+
if rank == 0:
|
1573 |
+
tqdm.write(f"SKIP: {sample['caption_options'][1]}")
|
1574 |
+
continue
|
1575 |
+
total += 1
|
1576 |
+
image = sample["image_options"][0]
|
1577 |
+
# image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
|
1578 |
+
image = image.resize((224, 224))
|
1579 |
+
debug_data = []
|
1580 |
+
final_losses = []
|
1581 |
+
for idx in range(2):
|
1582 |
+
text = sample["caption_options"][idx] # 1 is true caption
|
1583 |
+
# text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog"
|
1584 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1585 |
+
text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
|
1586 |
+
first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
|
1587 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
|
1588 |
+
first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox=True, return_all=False)
|
1589 |
+
if first_box is None:
|
1590 |
+
text_A = "the " + obj_A
|
1591 |
+
added_bbox_list = None
|
1592 |
+
else:
|
1593 |
+
text_A = visual_obj_A
|
1594 |
+
added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
|
1595 |
+
|
1596 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
|
1597 |
+
pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
|
1598 |
+
prebox_token_id, mask_prebox=False, debug=False, return_all=True)
|
1599 |
+
if pre_boxes is None:
|
1600 |
+
pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
|
1601 |
+
pre_scores = [1.0]
|
1602 |
+
|
1603 |
+
loss_list = []
|
1604 |
+
if only_highest:
|
1605 |
+
pre_boxes = [pre_boxes[0]]
|
1606 |
+
pre_scores = [pre_scores[0]]
|
1607 |
+
for pre_box, pre_score in zip(pre_boxes, pre_scores):
|
1608 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
|
1609 |
+
|
1610 |
+
encodings = tokenizer(
|
1611 |
+
prompt,
|
1612 |
+
padding="longest",
|
1613 |
+
truncation=True,
|
1614 |
+
return_tensors="pt",
|
1615 |
+
max_length=512,
|
1616 |
+
)
|
1617 |
+
input_ids = encodings["input_ids"]
|
1618 |
+
attention_mask = encodings["attention_mask"]
|
1619 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1620 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1621 |
+
image_nums = [1] * len(input_ids)
|
1622 |
+
vision_x = batch_images.cuda()
|
1623 |
+
lang_x = input_ids.cuda()
|
1624 |
+
attention_mask = attention_mask.cuda()
|
1625 |
+
labels = lang_x.clone()
|
1626 |
+
|
1627 |
+
|
1628 |
+
labels[input_ids == visual_token_id] = -100
|
1629 |
+
labels[input_ids == box_token_id] = -100
|
1630 |
+
labels[input_ids == endofattr_token_id] = -100
|
1631 |
+
labels[input_ids == previsual_token_id] = -100
|
1632 |
+
labels[input_ids == prebox_token_id] = -100
|
1633 |
+
labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
|
1634 |
+
labels[torch.roll(input_ids == box_token_id, 1)] = -100
|
1635 |
+
labels[:, 0] = -100
|
1636 |
+
labels[input_ids == tokenizer.pad_token_id] = -100
|
1637 |
+
labels[input_ids == media_token_id] = -100
|
1638 |
+
labels[input_ids == endofmedia_token_id] = -100
|
1639 |
+
|
1640 |
+
added_bbox_list = None
|
1641 |
+
if add_visual:
|
1642 |
+
added_bbox_list = []
|
1643 |
+
if first_box is not None:
|
1644 |
+
added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
|
1645 |
+
if pre_box is not None:
|
1646 |
+
added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
|
1647 |
+
if added_bbox_list is not None and len(added_bbox_list) == 0:
|
1648 |
+
added_bbox_list = None
|
1649 |
+
|
1650 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
1651 |
+
outputs = model(
|
1652 |
+
vision_x=vision_x,
|
1653 |
+
lang_x=lang_x,
|
1654 |
+
attention_mask=attention_mask,
|
1655 |
+
labels=labels,
|
1656 |
+
image_nums=image_nums,
|
1657 |
+
image_start_index_list=image_start_index_list,
|
1658 |
+
added_bbox_list=added_bbox_list,
|
1659 |
+
add_box=added_bbox_list is not None,
|
1660 |
+
relations=None,
|
1661 |
+
)
|
1662 |
+
loss_list.append((outputs["loss"].sum() / (outputs["loss"] != 0).sum()).item())
|
1663 |
+
debug_data.append([outputs, first_box, first_score, pre_box, pre_scores])
|
1664 |
+
final_loss = min(loss_list)
|
1665 |
+
final_losses.append(final_loss)
|
1666 |
+
if final_losses[0] >= final_losses[1]:
|
1667 |
+
correct += 1
|
1668 |
+
else:
|
1669 |
+
import pdb; pdb.set_trace()
|
1670 |
+
pass
|
1671 |
+
pbar.set_description(f"score: {correct / total:.4f} | {final_losses[0]:.2f} vs {final_losses[1]:.2f}")
|
1672 |
+
|
1673 |
+
|
1674 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1675 |
+
f.write(json.dumps([total, correct]))
|
1676 |
+
if world_size > 1:
|
1677 |
+
torch.distributed.barrier()
|
1678 |
+
if rank == 0:
|
1679 |
+
total = 0
|
1680 |
+
correct = 0
|
1681 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1682 |
+
for rank_i in range(world_size):
|
1683 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1684 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1685 |
+
total += total_part
|
1686 |
+
correct += correct_part
|
1687 |
+
score = correct / total
|
1688 |
+
print("score:", score, "total:", total)
|
1689 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
1690 |
+
pass
|
1691 |
+
else:
|
1692 |
+
score = 0.0
|
1693 |
+
if world_size > 1:
|
1694 |
+
torch.distributed.barrier()
|
1695 |
+
return score
|
1696 |
+
|
1697 |
+
|
1698 |
+
def evaluate_pisc(
|
1699 |
+
model,
|
1700 |
+
tokenizer,
|
1701 |
+
image_processor,
|
1702 |
+
batch_size,
|
1703 |
+
tsvfile,
|
1704 |
+
max_generation_length=20,
|
1705 |
+
num_beams=3,
|
1706 |
+
length_penalty=-2.0,
|
1707 |
+
device=-1,
|
1708 |
+
vis_embed_size=None,
|
1709 |
+
rank=0,
|
1710 |
+
world_size=1,
|
1711 |
+
id=0,
|
1712 |
+
add_visual=True,
|
1713 |
+
):
|
1714 |
+
from open_flamingo.train.instruction_template import PISC_TEMPLATES
|
1715 |
+
dataset_name = "pisc"
|
1716 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
1717 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
1718 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
1719 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
1720 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
1721 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
1722 |
+
model.train().cuda()
|
1723 |
+
|
1724 |
+
dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd")
|
1725 |
+
pbar = tqdm(dataset, disable=(rank != 0))
|
1726 |
+
|
1727 |
+
rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
|
1728 |
+
rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)}
|
1729 |
+
gt = []
|
1730 |
+
pred_scores = []
|
1731 |
+
for III, sample in enumerate(pbar):
|
1732 |
+
if III % world_size != rank:
|
1733 |
+
continue
|
1734 |
+
image_path, dataset, data = sample
|
1735 |
+
image = Image.open(image_path)
|
1736 |
+
size = image_processor.transforms[0].size
|
1737 |
+
image = image.resize((size, size))
|
1738 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
1739 |
+
boxA = data[0]
|
1740 |
+
boxB = data[1]
|
1741 |
+
gt_relation = data[2]
|
1742 |
+
losses = []
|
1743 |
+
for i_rel, option_rel in enumerate(rel_id_to_type):
|
1744 |
+
text = PISC_TEMPLATES[0].format(relation=option_rel)
|
1745 |
+
added_bbox = [
|
1746 |
+
torch.tensor([boxA]).cuda(),
|
1747 |
+
torch.tensor([boxB]).cuda(),
|
1748 |
+
]
|
1749 |
+
caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}"
|
1750 |
+
encodings = tokenizer(
|
1751 |
+
caption,
|
1752 |
+
padding="longest",
|
1753 |
+
truncation=True,
|
1754 |
+
return_tensors="pt",
|
1755 |
+
max_length=2000,
|
1756 |
+
)
|
1757 |
+
input_ids = encodings["input_ids"]
|
1758 |
+
attention_mask = encodings["attention_mask"]
|
1759 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
1760 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
1761 |
+
image_nums = [1] * len(input_ids)
|
1762 |
+
vision_x = batch_images.cuda()
|
1763 |
+
lang_x = input_ids.cuda()
|
1764 |
+
attention_mask = attention_mask.cuda()
|
1765 |
+
|
1766 |
+
labels = lang_x.clone()
|
1767 |
+
labels[labels == tokenizer.pad_token_id] = -100
|
1768 |
+
if add_visual:
|
1769 |
+
# endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True))
|
1770 |
+
# endofattr_next_token_index[1] += 1
|
1771 |
+
# endofattr_next_token_id = labels[endofattr_next_token_index]
|
1772 |
+
# </obj><visual><box></attr>NEXT_WORD
|
1773 |
+
# </obj> predict NEXT_WORD
|
1774 |
+
# <visual><box></attr> predict nothing
|
1775 |
+
labels[labels == visual_token_id] = -100
|
1776 |
+
labels[labels == box_token_id] = -100
|
1777 |
+
labels[labels == endofattr_token_id] = -100
|
1778 |
+
# labels[endofattr_next_token_index] = -100
|
1779 |
+
labels[:, 0] = -100
|
1780 |
+
answer_token_id = tokenizer(" Answer").input_ids[0]
|
1781 |
+
answer_token_loc = (input_ids == answer_token_id).nonzero()
|
1782 |
+
for batch_idx, idx in answer_token_loc:
|
1783 |
+
labels[batch_idx][:idx+2] = -100
|
1784 |
+
|
1785 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
1786 |
+
outputs = model(
|
1787 |
+
vision_x=vision_x,
|
1788 |
+
lang_x=lang_x,
|
1789 |
+
attention_mask=attention_mask,
|
1790 |
+
labels=labels,
|
1791 |
+
image_nums=image_nums,
|
1792 |
+
image_start_index_list=image_start_index_list,
|
1793 |
+
added_bbox_list=added_bbox,
|
1794 |
+
add_box=added_bbox is not None,
|
1795 |
+
)
|
1796 |
+
loss_total = outputs.loss.reshape(labels.shape[0], -1)
|
1797 |
+
loss = loss_total.sum() / (loss_total != 0).sum()
|
1798 |
+
losses.append(loss.item())
|
1799 |
+
pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum())
|
1800 |
+
gt.append(rel_type_to_id[gt_relation])
|
1801 |
+
gt = np.array(gt)
|
1802 |
+
pred_scores = np.array(pred_scores)
|
1803 |
+
pred = pred_scores.argmax(1)
|
1804 |
+
|
1805 |
+
|
1806 |
+
print("total num:", len(gt))
|
1807 |
+
recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
|
1808 |
+
print("recalls:", recalls)
|
1809 |
+
|
1810 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
1811 |
+
f.write(json.dumps([gt.tolist(), pred.tolist()]))
|
1812 |
+
if world_size > 1:
|
1813 |
+
torch.distributed.barrier()
|
1814 |
+
if rank == 0:
|
1815 |
+
gt = []
|
1816 |
+
pred = []
|
1817 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
1818 |
+
for rank_i in range(world_size):
|
1819 |
+
[gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
1820 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
1821 |
+
gt.extend(gt_part)
|
1822 |
+
pred.extend(pred_part)
|
1823 |
+
print("total num:", len(gt))
|
1824 |
+
recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
|
1825 |
+
print("recalls:", recalls)
|
1826 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f:
|
1827 |
+
f.write(f"{gt}\n")
|
1828 |
+
f.write(f"{pred}\n")
|
1829 |
+
f.write(f"{recalls}\n")
|
1830 |
+
score = 0.0
|
1831 |
+
if world_size > 1:
|
1832 |
+
torch.distributed.barrier()
|
1833 |
+
return score
|
1834 |
+
|
1835 |
+
|
1836 |
+
|
1837 |
+
if __name__ == "__main__":
|
1838 |
+
main()
|
multimodal/build/lib/open_flamingo/eval/imagenet_utils.py
ADDED
@@ -0,0 +1,1007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# classnames via https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/imagenet_classnames.py#L1
|
2 |
+
openai_imagenet_classnames = [
|
3 |
+
"tench",
|
4 |
+
"goldfish",
|
5 |
+
"great white shark",
|
6 |
+
"tiger shark",
|
7 |
+
"hammerhead shark",
|
8 |
+
"electric ray",
|
9 |
+
"stingray",
|
10 |
+
"rooster",
|
11 |
+
"hen",
|
12 |
+
"ostrich",
|
13 |
+
"brambling",
|
14 |
+
"goldfinch",
|
15 |
+
"house finch",
|
16 |
+
"junco",
|
17 |
+
"indigo bunting",
|
18 |
+
"American robin",
|
19 |
+
"bulbul",
|
20 |
+
"jay",
|
21 |
+
"magpie",
|
22 |
+
"chickadee",
|
23 |
+
"American dipper",
|
24 |
+
"kite (bird of prey)",
|
25 |
+
"bald eagle",
|
26 |
+
"vulture",
|
27 |
+
"great grey owl",
|
28 |
+
"fire salamander",
|
29 |
+
"smooth newt",
|
30 |
+
"newt",
|
31 |
+
"spotted salamander",
|
32 |
+
"axolotl",
|
33 |
+
"American bullfrog",
|
34 |
+
"tree frog",
|
35 |
+
"tailed frog",
|
36 |
+
"loggerhead sea turtle",
|
37 |
+
"leatherback sea turtle",
|
38 |
+
"mud turtle",
|
39 |
+
"terrapin",
|
40 |
+
"box turtle",
|
41 |
+
"banded gecko",
|
42 |
+
"green iguana",
|
43 |
+
"Carolina anole",
|
44 |
+
"desert grassland whiptail lizard",
|
45 |
+
"agama",
|
46 |
+
"frilled-necked lizard",
|
47 |
+
"alligator lizard",
|
48 |
+
"Gila monster",
|
49 |
+
"European green lizard",
|
50 |
+
"chameleon",
|
51 |
+
"Komodo dragon",
|
52 |
+
"Nile crocodile",
|
53 |
+
"American alligator",
|
54 |
+
"triceratops",
|
55 |
+
"worm snake",
|
56 |
+
"ring-necked snake",
|
57 |
+
"eastern hog-nosed snake",
|
58 |
+
"smooth green snake",
|
59 |
+
"kingsnake",
|
60 |
+
"garter snake",
|
61 |
+
"water snake",
|
62 |
+
"vine snake",
|
63 |
+
"night snake",
|
64 |
+
"boa constrictor",
|
65 |
+
"African rock python",
|
66 |
+
"Indian cobra",
|
67 |
+
"green mamba",
|
68 |
+
"sea snake",
|
69 |
+
"Saharan horned viper",
|
70 |
+
"eastern diamondback rattlesnake",
|
71 |
+
"sidewinder rattlesnake",
|
72 |
+
"trilobite",
|
73 |
+
"harvestman",
|
74 |
+
"scorpion",
|
75 |
+
"yellow garden spider",
|
76 |
+
"barn spider",
|
77 |
+
"European garden spider",
|
78 |
+
"southern black widow",
|
79 |
+
"tarantula",
|
80 |
+
"wolf spider",
|
81 |
+
"tick",
|
82 |
+
"centipede",
|
83 |
+
"black grouse",
|
84 |
+
"ptarmigan",
|
85 |
+
"ruffed grouse",
|
86 |
+
"prairie grouse",
|
87 |
+
"peafowl",
|
88 |
+
"quail",
|
89 |
+
"partridge",
|
90 |
+
"african grey parrot",
|
91 |
+
"macaw",
|
92 |
+
"sulphur-crested cockatoo",
|
93 |
+
"lorikeet",
|
94 |
+
"coucal",
|
95 |
+
"bee eater",
|
96 |
+
"hornbill",
|
97 |
+
"hummingbird",
|
98 |
+
"jacamar",
|
99 |
+
"toucan",
|
100 |
+
"duck",
|
101 |
+
"red-breasted merganser",
|
102 |
+
"goose",
|
103 |
+
"black swan",
|
104 |
+
"tusker",
|
105 |
+
"echidna",
|
106 |
+
"platypus",
|
107 |
+
"wallaby",
|
108 |
+
"koala",
|
109 |
+
"wombat",
|
110 |
+
"jellyfish",
|
111 |
+
"sea anemone",
|
112 |
+
"brain coral",
|
113 |
+
"flatworm",
|
114 |
+
"nematode",
|
115 |
+
"conch",
|
116 |
+
"snail",
|
117 |
+
"slug",
|
118 |
+
"sea slug",
|
119 |
+
"chiton",
|
120 |
+
"chambered nautilus",
|
121 |
+
"Dungeness crab",
|
122 |
+
"rock crab",
|
123 |
+
"fiddler crab",
|
124 |
+
"red king crab",
|
125 |
+
"American lobster",
|
126 |
+
"spiny lobster",
|
127 |
+
"crayfish",
|
128 |
+
"hermit crab",
|
129 |
+
"isopod",
|
130 |
+
"white stork",
|
131 |
+
"black stork",
|
132 |
+
"spoonbill",
|
133 |
+
"flamingo",
|
134 |
+
"little blue heron",
|
135 |
+
"great egret",
|
136 |
+
"bittern bird",
|
137 |
+
"crane bird",
|
138 |
+
"limpkin",
|
139 |
+
"common gallinule",
|
140 |
+
"American coot",
|
141 |
+
"bustard",
|
142 |
+
"ruddy turnstone",
|
143 |
+
"dunlin",
|
144 |
+
"common redshank",
|
145 |
+
"dowitcher",
|
146 |
+
"oystercatcher",
|
147 |
+
"pelican",
|
148 |
+
"king penguin",
|
149 |
+
"albatross",
|
150 |
+
"grey whale",
|
151 |
+
"killer whale",
|
152 |
+
"dugong",
|
153 |
+
"sea lion",
|
154 |
+
"Chihuahua",
|
155 |
+
"Japanese Chin",
|
156 |
+
"Maltese",
|
157 |
+
"Pekingese",
|
158 |
+
"Shih Tzu",
|
159 |
+
"King Charles Spaniel",
|
160 |
+
"Papillon",
|
161 |
+
"toy terrier",
|
162 |
+
"Rhodesian Ridgeback",
|
163 |
+
"Afghan Hound",
|
164 |
+
"Basset Hound",
|
165 |
+
"Beagle",
|
166 |
+
"Bloodhound",
|
167 |
+
"Bluetick Coonhound",
|
168 |
+
"Black and Tan Coonhound",
|
169 |
+
"Treeing Walker Coonhound",
|
170 |
+
"English foxhound",
|
171 |
+
"Redbone Coonhound",
|
172 |
+
"borzoi",
|
173 |
+
"Irish Wolfhound",
|
174 |
+
"Italian Greyhound",
|
175 |
+
"Whippet",
|
176 |
+
"Ibizan Hound",
|
177 |
+
"Norwegian Elkhound",
|
178 |
+
"Otterhound",
|
179 |
+
"Saluki",
|
180 |
+
"Scottish Deerhound",
|
181 |
+
"Weimaraner",
|
182 |
+
"Staffordshire Bull Terrier",
|
183 |
+
"American Staffordshire Terrier",
|
184 |
+
"Bedlington Terrier",
|
185 |
+
"Border Terrier",
|
186 |
+
"Kerry Blue Terrier",
|
187 |
+
"Irish Terrier",
|
188 |
+
"Norfolk Terrier",
|
189 |
+
"Norwich Terrier",
|
190 |
+
"Yorkshire Terrier",
|
191 |
+
"Wire Fox Terrier",
|
192 |
+
"Lakeland Terrier",
|
193 |
+
"Sealyham Terrier",
|
194 |
+
"Airedale Terrier",
|
195 |
+
"Cairn Terrier",
|
196 |
+
"Australian Terrier",
|
197 |
+
"Dandie Dinmont Terrier",
|
198 |
+
"Boston Terrier",
|
199 |
+
"Miniature Schnauzer",
|
200 |
+
"Giant Schnauzer",
|
201 |
+
"Standard Schnauzer",
|
202 |
+
"Scottish Terrier",
|
203 |
+
"Tibetan Terrier",
|
204 |
+
"Australian Silky Terrier",
|
205 |
+
"Soft-coated Wheaten Terrier",
|
206 |
+
"West Highland White Terrier",
|
207 |
+
"Lhasa Apso",
|
208 |
+
"Flat-Coated Retriever",
|
209 |
+
"Curly-coated Retriever",
|
210 |
+
"Golden Retriever",
|
211 |
+
"Labrador Retriever",
|
212 |
+
"Chesapeake Bay Retriever",
|
213 |
+
"German Shorthaired Pointer",
|
214 |
+
"Vizsla",
|
215 |
+
"English Setter",
|
216 |
+
"Irish Setter",
|
217 |
+
"Gordon Setter",
|
218 |
+
"Brittany dog",
|
219 |
+
"Clumber Spaniel",
|
220 |
+
"English Springer Spaniel",
|
221 |
+
"Welsh Springer Spaniel",
|
222 |
+
"Cocker Spaniel",
|
223 |
+
"Sussex Spaniel",
|
224 |
+
"Irish Water Spaniel",
|
225 |
+
"Kuvasz",
|
226 |
+
"Schipperke",
|
227 |
+
"Groenendael dog",
|
228 |
+
"Malinois",
|
229 |
+
"Briard",
|
230 |
+
"Australian Kelpie",
|
231 |
+
"Komondor",
|
232 |
+
"Old English Sheepdog",
|
233 |
+
"Shetland Sheepdog",
|
234 |
+
"collie",
|
235 |
+
"Border Collie",
|
236 |
+
"Bouvier des Flandres dog",
|
237 |
+
"Rottweiler",
|
238 |
+
"German Shepherd Dog",
|
239 |
+
"Dobermann",
|
240 |
+
"Miniature Pinscher",
|
241 |
+
"Greater Swiss Mountain Dog",
|
242 |
+
"Bernese Mountain Dog",
|
243 |
+
"Appenzeller Sennenhund",
|
244 |
+
"Entlebucher Sennenhund",
|
245 |
+
"Boxer",
|
246 |
+
"Bullmastiff",
|
247 |
+
"Tibetan Mastiff",
|
248 |
+
"French Bulldog",
|
249 |
+
"Great Dane",
|
250 |
+
"St. Bernard",
|
251 |
+
"husky",
|
252 |
+
"Alaskan Malamute",
|
253 |
+
"Siberian Husky",
|
254 |
+
"Dalmatian",
|
255 |
+
"Affenpinscher",
|
256 |
+
"Basenji",
|
257 |
+
"pug",
|
258 |
+
"Leonberger",
|
259 |
+
"Newfoundland dog",
|
260 |
+
"Great Pyrenees dog",
|
261 |
+
"Samoyed",
|
262 |
+
"Pomeranian",
|
263 |
+
"Chow Chow",
|
264 |
+
"Keeshond",
|
265 |
+
"brussels griffon",
|
266 |
+
"Pembroke Welsh Corgi",
|
267 |
+
"Cardigan Welsh Corgi",
|
268 |
+
"Toy Poodle",
|
269 |
+
"Miniature Poodle",
|
270 |
+
"Standard Poodle",
|
271 |
+
"Mexican hairless dog (xoloitzcuintli)",
|
272 |
+
"grey wolf",
|
273 |
+
"Alaskan tundra wolf",
|
274 |
+
"red wolf or maned wolf",
|
275 |
+
"coyote",
|
276 |
+
"dingo",
|
277 |
+
"dhole",
|
278 |
+
"African wild dog",
|
279 |
+
"hyena",
|
280 |
+
"red fox",
|
281 |
+
"kit fox",
|
282 |
+
"Arctic fox",
|
283 |
+
"grey fox",
|
284 |
+
"tabby cat",
|
285 |
+
"tiger cat",
|
286 |
+
"Persian cat",
|
287 |
+
"Siamese cat",
|
288 |
+
"Egyptian Mau",
|
289 |
+
"cougar",
|
290 |
+
"lynx",
|
291 |
+
"leopard",
|
292 |
+
"snow leopard",
|
293 |
+
"jaguar",
|
294 |
+
"lion",
|
295 |
+
"tiger",
|
296 |
+
"cheetah",
|
297 |
+
"brown bear",
|
298 |
+
"American black bear",
|
299 |
+
"polar bear",
|
300 |
+
"sloth bear",
|
301 |
+
"mongoose",
|
302 |
+
"meerkat",
|
303 |
+
"tiger beetle",
|
304 |
+
"ladybug",
|
305 |
+
"ground beetle",
|
306 |
+
"longhorn beetle",
|
307 |
+
"leaf beetle",
|
308 |
+
"dung beetle",
|
309 |
+
"rhinoceros beetle",
|
310 |
+
"weevil",
|
311 |
+
"fly",
|
312 |
+
"bee",
|
313 |
+
"ant",
|
314 |
+
"grasshopper",
|
315 |
+
"cricket insect",
|
316 |
+
"stick insect",
|
317 |
+
"cockroach",
|
318 |
+
"praying mantis",
|
319 |
+
"cicada",
|
320 |
+
"leafhopper",
|
321 |
+
"lacewing",
|
322 |
+
"dragonfly",
|
323 |
+
"damselfly",
|
324 |
+
"red admiral butterfly",
|
325 |
+
"ringlet butterfly",
|
326 |
+
"monarch butterfly",
|
327 |
+
"small white butterfly",
|
328 |
+
"sulphur butterfly",
|
329 |
+
"gossamer-winged butterfly",
|
330 |
+
"starfish",
|
331 |
+
"sea urchin",
|
332 |
+
"sea cucumber",
|
333 |
+
"cottontail rabbit",
|
334 |
+
"hare",
|
335 |
+
"Angora rabbit",
|
336 |
+
"hamster",
|
337 |
+
"porcupine",
|
338 |
+
"fox squirrel",
|
339 |
+
"marmot",
|
340 |
+
"beaver",
|
341 |
+
"guinea pig",
|
342 |
+
"common sorrel horse",
|
343 |
+
"zebra",
|
344 |
+
"pig",
|
345 |
+
"wild boar",
|
346 |
+
"warthog",
|
347 |
+
"hippopotamus",
|
348 |
+
"ox",
|
349 |
+
"water buffalo",
|
350 |
+
"bison",
|
351 |
+
"ram (adult male sheep)",
|
352 |
+
"bighorn sheep",
|
353 |
+
"Alpine ibex",
|
354 |
+
"hartebeest",
|
355 |
+
"impala (antelope)",
|
356 |
+
"gazelle",
|
357 |
+
"arabian camel",
|
358 |
+
"llama",
|
359 |
+
"weasel",
|
360 |
+
"mink",
|
361 |
+
"European polecat",
|
362 |
+
"black-footed ferret",
|
363 |
+
"otter",
|
364 |
+
"skunk",
|
365 |
+
"badger",
|
366 |
+
"armadillo",
|
367 |
+
"three-toed sloth",
|
368 |
+
"orangutan",
|
369 |
+
"gorilla",
|
370 |
+
"chimpanzee",
|
371 |
+
"gibbon",
|
372 |
+
"siamang",
|
373 |
+
"guenon",
|
374 |
+
"patas monkey",
|
375 |
+
"baboon",
|
376 |
+
"macaque",
|
377 |
+
"langur",
|
378 |
+
"black-and-white colobus",
|
379 |
+
"proboscis monkey",
|
380 |
+
"marmoset",
|
381 |
+
"white-headed capuchin",
|
382 |
+
"howler monkey",
|
383 |
+
"titi monkey",
|
384 |
+
"Geoffroy's spider monkey",
|
385 |
+
"common squirrel monkey",
|
386 |
+
"ring-tailed lemur",
|
387 |
+
"indri",
|
388 |
+
"Asian elephant",
|
389 |
+
"African bush elephant",
|
390 |
+
"red panda",
|
391 |
+
"giant panda",
|
392 |
+
"snoek fish",
|
393 |
+
"eel",
|
394 |
+
"silver salmon",
|
395 |
+
"rock beauty fish",
|
396 |
+
"clownfish",
|
397 |
+
"sturgeon",
|
398 |
+
"gar fish",
|
399 |
+
"lionfish",
|
400 |
+
"pufferfish",
|
401 |
+
"abacus",
|
402 |
+
"abaya",
|
403 |
+
"academic gown",
|
404 |
+
"accordion",
|
405 |
+
"acoustic guitar",
|
406 |
+
"aircraft carrier",
|
407 |
+
"airliner",
|
408 |
+
"airship",
|
409 |
+
"altar",
|
410 |
+
"ambulance",
|
411 |
+
"amphibious vehicle",
|
412 |
+
"analog clock",
|
413 |
+
"apiary",
|
414 |
+
"apron",
|
415 |
+
"trash can",
|
416 |
+
"assault rifle",
|
417 |
+
"backpack",
|
418 |
+
"bakery",
|
419 |
+
"balance beam",
|
420 |
+
"balloon",
|
421 |
+
"ballpoint pen",
|
422 |
+
"Band-Aid",
|
423 |
+
"banjo",
|
424 |
+
"baluster / handrail",
|
425 |
+
"barbell",
|
426 |
+
"barber chair",
|
427 |
+
"barbershop",
|
428 |
+
"barn",
|
429 |
+
"barometer",
|
430 |
+
"barrel",
|
431 |
+
"wheelbarrow",
|
432 |
+
"baseball",
|
433 |
+
"basketball",
|
434 |
+
"bassinet",
|
435 |
+
"bassoon",
|
436 |
+
"swimming cap",
|
437 |
+
"bath towel",
|
438 |
+
"bathtub",
|
439 |
+
"station wagon",
|
440 |
+
"lighthouse",
|
441 |
+
"beaker",
|
442 |
+
"military hat (bearskin or shako)",
|
443 |
+
"beer bottle",
|
444 |
+
"beer glass",
|
445 |
+
"bell tower",
|
446 |
+
"baby bib",
|
447 |
+
"tandem bicycle",
|
448 |
+
"bikini",
|
449 |
+
"ring binder",
|
450 |
+
"binoculars",
|
451 |
+
"birdhouse",
|
452 |
+
"boathouse",
|
453 |
+
"bobsleigh",
|
454 |
+
"bolo tie",
|
455 |
+
"poke bonnet",
|
456 |
+
"bookcase",
|
457 |
+
"bookstore",
|
458 |
+
"bottle cap",
|
459 |
+
"hunting bow",
|
460 |
+
"bow tie",
|
461 |
+
"brass memorial plaque",
|
462 |
+
"bra",
|
463 |
+
"breakwater",
|
464 |
+
"breastplate",
|
465 |
+
"broom",
|
466 |
+
"bucket",
|
467 |
+
"buckle",
|
468 |
+
"bulletproof vest",
|
469 |
+
"high-speed train",
|
470 |
+
"butcher shop",
|
471 |
+
"taxicab",
|
472 |
+
"cauldron",
|
473 |
+
"candle",
|
474 |
+
"cannon",
|
475 |
+
"canoe",
|
476 |
+
"can opener",
|
477 |
+
"cardigan",
|
478 |
+
"car mirror",
|
479 |
+
"carousel",
|
480 |
+
"tool kit",
|
481 |
+
"cardboard box / carton",
|
482 |
+
"car wheel",
|
483 |
+
"automated teller machine",
|
484 |
+
"cassette",
|
485 |
+
"cassette player",
|
486 |
+
"castle",
|
487 |
+
"catamaran",
|
488 |
+
"CD player",
|
489 |
+
"cello",
|
490 |
+
"mobile phone",
|
491 |
+
"chain",
|
492 |
+
"chain-link fence",
|
493 |
+
"chain mail",
|
494 |
+
"chainsaw",
|
495 |
+
"storage chest",
|
496 |
+
"chiffonier",
|
497 |
+
"bell or wind chime",
|
498 |
+
"china cabinet",
|
499 |
+
"Christmas stocking",
|
500 |
+
"church",
|
501 |
+
"movie theater",
|
502 |
+
"cleaver",
|
503 |
+
"cliff dwelling",
|
504 |
+
"cloak",
|
505 |
+
"clogs",
|
506 |
+
"cocktail shaker",
|
507 |
+
"coffee mug",
|
508 |
+
"coffeemaker",
|
509 |
+
"spiral or coil",
|
510 |
+
"combination lock",
|
511 |
+
"computer keyboard",
|
512 |
+
"candy store",
|
513 |
+
"container ship",
|
514 |
+
"convertible",
|
515 |
+
"corkscrew",
|
516 |
+
"cornet",
|
517 |
+
"cowboy boot",
|
518 |
+
"cowboy hat",
|
519 |
+
"cradle",
|
520 |
+
"construction crane",
|
521 |
+
"crash helmet",
|
522 |
+
"crate",
|
523 |
+
"infant bed",
|
524 |
+
"Crock Pot",
|
525 |
+
"croquet ball",
|
526 |
+
"crutch",
|
527 |
+
"cuirass",
|
528 |
+
"dam",
|
529 |
+
"desk",
|
530 |
+
"desktop computer",
|
531 |
+
"rotary dial telephone",
|
532 |
+
"diaper",
|
533 |
+
"digital clock",
|
534 |
+
"digital watch",
|
535 |
+
"dining table",
|
536 |
+
"dishcloth",
|
537 |
+
"dishwasher",
|
538 |
+
"disc brake",
|
539 |
+
"dock",
|
540 |
+
"dog sled",
|
541 |
+
"dome",
|
542 |
+
"doormat",
|
543 |
+
"drilling rig",
|
544 |
+
"drum",
|
545 |
+
"drumstick",
|
546 |
+
"dumbbell",
|
547 |
+
"Dutch oven",
|
548 |
+
"electric fan",
|
549 |
+
"electric guitar",
|
550 |
+
"electric locomotive",
|
551 |
+
"entertainment center",
|
552 |
+
"envelope",
|
553 |
+
"espresso machine",
|
554 |
+
"face powder",
|
555 |
+
"feather boa",
|
556 |
+
"filing cabinet",
|
557 |
+
"fireboat",
|
558 |
+
"fire truck",
|
559 |
+
"fire screen",
|
560 |
+
"flagpole",
|
561 |
+
"flute",
|
562 |
+
"folding chair",
|
563 |
+
"football helmet",
|
564 |
+
"forklift",
|
565 |
+
"fountain",
|
566 |
+
"fountain pen",
|
567 |
+
"four-poster bed",
|
568 |
+
"freight car",
|
569 |
+
"French horn",
|
570 |
+
"frying pan",
|
571 |
+
"fur coat",
|
572 |
+
"garbage truck",
|
573 |
+
"gas mask or respirator",
|
574 |
+
"gas pump",
|
575 |
+
"goblet",
|
576 |
+
"go-kart",
|
577 |
+
"golf ball",
|
578 |
+
"golf cart",
|
579 |
+
"gondola",
|
580 |
+
"gong",
|
581 |
+
"gown",
|
582 |
+
"grand piano",
|
583 |
+
"greenhouse",
|
584 |
+
"radiator grille",
|
585 |
+
"grocery store",
|
586 |
+
"guillotine",
|
587 |
+
"hair clip",
|
588 |
+
"hair spray",
|
589 |
+
"half-track",
|
590 |
+
"hammer",
|
591 |
+
"hamper",
|
592 |
+
"hair dryer",
|
593 |
+
"hand-held computer",
|
594 |
+
"handkerchief",
|
595 |
+
"hard disk drive",
|
596 |
+
"harmonica",
|
597 |
+
"harp",
|
598 |
+
"combine harvester",
|
599 |
+
"hatchet",
|
600 |
+
"holster",
|
601 |
+
"home theater",
|
602 |
+
"honeycomb",
|
603 |
+
"hook",
|
604 |
+
"hoop skirt",
|
605 |
+
"gymnastic horizontal bar",
|
606 |
+
"horse-drawn vehicle",
|
607 |
+
"hourglass",
|
608 |
+
"iPod",
|
609 |
+
"clothes iron",
|
610 |
+
"carved pumpkin",
|
611 |
+
"jeans",
|
612 |
+
"jeep",
|
613 |
+
"T-shirt",
|
614 |
+
"jigsaw puzzle",
|
615 |
+
"rickshaw",
|
616 |
+
"joystick",
|
617 |
+
"kimono",
|
618 |
+
"knee pad",
|
619 |
+
"knot",
|
620 |
+
"lab coat",
|
621 |
+
"ladle",
|
622 |
+
"lampshade",
|
623 |
+
"laptop computer",
|
624 |
+
"lawn mower",
|
625 |
+
"lens cap",
|
626 |
+
"letter opener",
|
627 |
+
"library",
|
628 |
+
"lifeboat",
|
629 |
+
"lighter",
|
630 |
+
"limousine",
|
631 |
+
"ocean liner",
|
632 |
+
"lipstick",
|
633 |
+
"slip-on shoe",
|
634 |
+
"lotion",
|
635 |
+
"music speaker",
|
636 |
+
"loupe magnifying glass",
|
637 |
+
"sawmill",
|
638 |
+
"magnetic compass",
|
639 |
+
"messenger bag",
|
640 |
+
"mailbox",
|
641 |
+
"tights",
|
642 |
+
"one-piece bathing suit",
|
643 |
+
"manhole cover",
|
644 |
+
"maraca",
|
645 |
+
"marimba",
|
646 |
+
"mask",
|
647 |
+
"matchstick",
|
648 |
+
"maypole",
|
649 |
+
"maze",
|
650 |
+
"measuring cup",
|
651 |
+
"medicine cabinet",
|
652 |
+
"megalith",
|
653 |
+
"microphone",
|
654 |
+
"microwave oven",
|
655 |
+
"military uniform",
|
656 |
+
"milk can",
|
657 |
+
"minibus",
|
658 |
+
"miniskirt",
|
659 |
+
"minivan",
|
660 |
+
"missile",
|
661 |
+
"mitten",
|
662 |
+
"mixing bowl",
|
663 |
+
"mobile home",
|
664 |
+
"ford model t",
|
665 |
+
"modem",
|
666 |
+
"monastery",
|
667 |
+
"monitor",
|
668 |
+
"moped",
|
669 |
+
"mortar and pestle",
|
670 |
+
"graduation cap",
|
671 |
+
"mosque",
|
672 |
+
"mosquito net",
|
673 |
+
"vespa",
|
674 |
+
"mountain bike",
|
675 |
+
"tent",
|
676 |
+
"computer mouse",
|
677 |
+
"mousetrap",
|
678 |
+
"moving van",
|
679 |
+
"muzzle",
|
680 |
+
"metal nail",
|
681 |
+
"neck brace",
|
682 |
+
"necklace",
|
683 |
+
"baby pacifier",
|
684 |
+
"notebook computer",
|
685 |
+
"obelisk",
|
686 |
+
"oboe",
|
687 |
+
"ocarina",
|
688 |
+
"odometer",
|
689 |
+
"oil filter",
|
690 |
+
"pipe organ",
|
691 |
+
"oscilloscope",
|
692 |
+
"overskirt",
|
693 |
+
"bullock cart",
|
694 |
+
"oxygen mask",
|
695 |
+
"product packet / packaging",
|
696 |
+
"paddle",
|
697 |
+
"paddle wheel",
|
698 |
+
"padlock",
|
699 |
+
"paintbrush",
|
700 |
+
"pajamas",
|
701 |
+
"palace",
|
702 |
+
"pan flute",
|
703 |
+
"paper towel",
|
704 |
+
"parachute",
|
705 |
+
"parallel bars",
|
706 |
+
"park bench",
|
707 |
+
"parking meter",
|
708 |
+
"railroad car",
|
709 |
+
"patio",
|
710 |
+
"payphone",
|
711 |
+
"pedestal",
|
712 |
+
"pencil case",
|
713 |
+
"pencil sharpener",
|
714 |
+
"perfume",
|
715 |
+
"Petri dish",
|
716 |
+
"photocopier",
|
717 |
+
"plectrum",
|
718 |
+
"Pickelhaube",
|
719 |
+
"picket fence",
|
720 |
+
"pickup truck",
|
721 |
+
"pier",
|
722 |
+
"piggy bank",
|
723 |
+
"pill bottle",
|
724 |
+
"pillow",
|
725 |
+
"ping-pong ball",
|
726 |
+
"pinwheel",
|
727 |
+
"pirate ship",
|
728 |
+
"drink pitcher",
|
729 |
+
"block plane",
|
730 |
+
"planetarium",
|
731 |
+
"plastic bag",
|
732 |
+
"plate rack",
|
733 |
+
"farm plow",
|
734 |
+
"plunger",
|
735 |
+
"Polaroid camera",
|
736 |
+
"pole",
|
737 |
+
"police van",
|
738 |
+
"poncho",
|
739 |
+
"pool table",
|
740 |
+
"soda bottle",
|
741 |
+
"plant pot",
|
742 |
+
"potter's wheel",
|
743 |
+
"power drill",
|
744 |
+
"prayer rug",
|
745 |
+
"printer",
|
746 |
+
"prison",
|
747 |
+
"missile",
|
748 |
+
"projector",
|
749 |
+
"hockey puck",
|
750 |
+
"punching bag",
|
751 |
+
"purse",
|
752 |
+
"quill",
|
753 |
+
"quilt",
|
754 |
+
"race car",
|
755 |
+
"racket",
|
756 |
+
"radiator",
|
757 |
+
"radio",
|
758 |
+
"radio telescope",
|
759 |
+
"rain barrel",
|
760 |
+
"recreational vehicle",
|
761 |
+
"fishing casting reel",
|
762 |
+
"reflex camera",
|
763 |
+
"refrigerator",
|
764 |
+
"remote control",
|
765 |
+
"restaurant",
|
766 |
+
"revolver",
|
767 |
+
"rifle",
|
768 |
+
"rocking chair",
|
769 |
+
"rotisserie",
|
770 |
+
"eraser",
|
771 |
+
"rugby ball",
|
772 |
+
"ruler measuring stick",
|
773 |
+
"sneaker",
|
774 |
+
"safe",
|
775 |
+
"safety pin",
|
776 |
+
"salt shaker",
|
777 |
+
"sandal",
|
778 |
+
"sarong",
|
779 |
+
"saxophone",
|
780 |
+
"scabbard",
|
781 |
+
"weighing scale",
|
782 |
+
"school bus",
|
783 |
+
"schooner",
|
784 |
+
"scoreboard",
|
785 |
+
"CRT monitor",
|
786 |
+
"screw",
|
787 |
+
"screwdriver",
|
788 |
+
"seat belt",
|
789 |
+
"sewing machine",
|
790 |
+
"shield",
|
791 |
+
"shoe store",
|
792 |
+
"shoji screen / room divider",
|
793 |
+
"shopping basket",
|
794 |
+
"shopping cart",
|
795 |
+
"shovel",
|
796 |
+
"shower cap",
|
797 |
+
"shower curtain",
|
798 |
+
"ski",
|
799 |
+
"balaclava ski mask",
|
800 |
+
"sleeping bag",
|
801 |
+
"slide rule",
|
802 |
+
"sliding door",
|
803 |
+
"slot machine",
|
804 |
+
"snorkel",
|
805 |
+
"snowmobile",
|
806 |
+
"snowplow",
|
807 |
+
"soap dispenser",
|
808 |
+
"soccer ball",
|
809 |
+
"sock",
|
810 |
+
"solar thermal collector",
|
811 |
+
"sombrero",
|
812 |
+
"soup bowl",
|
813 |
+
"keyboard space bar",
|
814 |
+
"space heater",
|
815 |
+
"space shuttle",
|
816 |
+
"spatula",
|
817 |
+
"motorboat",
|
818 |
+
"spider web",
|
819 |
+
"spindle",
|
820 |
+
"sports car",
|
821 |
+
"spotlight",
|
822 |
+
"stage",
|
823 |
+
"steam locomotive",
|
824 |
+
"through arch bridge",
|
825 |
+
"steel drum",
|
826 |
+
"stethoscope",
|
827 |
+
"scarf",
|
828 |
+
"stone wall",
|
829 |
+
"stopwatch",
|
830 |
+
"stove",
|
831 |
+
"strainer",
|
832 |
+
"tram",
|
833 |
+
"stretcher",
|
834 |
+
"couch",
|
835 |
+
"stupa",
|
836 |
+
"submarine",
|
837 |
+
"suit",
|
838 |
+
"sundial",
|
839 |
+
"sunglasses",
|
840 |
+
"sunglasses",
|
841 |
+
"sunscreen",
|
842 |
+
"suspension bridge",
|
843 |
+
"mop",
|
844 |
+
"sweatshirt",
|
845 |
+
"swim trunks / shorts",
|
846 |
+
"swing",
|
847 |
+
"electrical switch",
|
848 |
+
"syringe",
|
849 |
+
"table lamp",
|
850 |
+
"tank",
|
851 |
+
"tape player",
|
852 |
+
"teapot",
|
853 |
+
"teddy bear",
|
854 |
+
"television",
|
855 |
+
"tennis ball",
|
856 |
+
"thatched roof",
|
857 |
+
"front curtain",
|
858 |
+
"thimble",
|
859 |
+
"threshing machine",
|
860 |
+
"throne",
|
861 |
+
"tile roof",
|
862 |
+
"toaster",
|
863 |
+
"tobacco shop",
|
864 |
+
"toilet seat",
|
865 |
+
"torch",
|
866 |
+
"totem pole",
|
867 |
+
"tow truck",
|
868 |
+
"toy store",
|
869 |
+
"tractor",
|
870 |
+
"semi-trailer truck",
|
871 |
+
"tray",
|
872 |
+
"trench coat",
|
873 |
+
"tricycle",
|
874 |
+
"trimaran",
|
875 |
+
"tripod",
|
876 |
+
"triumphal arch",
|
877 |
+
"trolleybus",
|
878 |
+
"trombone",
|
879 |
+
"hot tub",
|
880 |
+
"turnstile",
|
881 |
+
"typewriter keyboard",
|
882 |
+
"umbrella",
|
883 |
+
"unicycle",
|
884 |
+
"upright piano",
|
885 |
+
"vacuum cleaner",
|
886 |
+
"vase",
|
887 |
+
"vaulted or arched ceiling",
|
888 |
+
"velvet fabric",
|
889 |
+
"vending machine",
|
890 |
+
"vestment",
|
891 |
+
"viaduct",
|
892 |
+
"violin",
|
893 |
+
"volleyball",
|
894 |
+
"waffle iron",
|
895 |
+
"wall clock",
|
896 |
+
"wallet",
|
897 |
+
"wardrobe",
|
898 |
+
"military aircraft",
|
899 |
+
"sink",
|
900 |
+
"washing machine",
|
901 |
+
"water bottle",
|
902 |
+
"water jug",
|
903 |
+
"water tower",
|
904 |
+
"whiskey jug",
|
905 |
+
"whistle",
|
906 |
+
"hair wig",
|
907 |
+
"window screen",
|
908 |
+
"window shade",
|
909 |
+
"Windsor tie",
|
910 |
+
"wine bottle",
|
911 |
+
"airplane wing",
|
912 |
+
"wok",
|
913 |
+
"wooden spoon",
|
914 |
+
"wool",
|
915 |
+
"split-rail fence",
|
916 |
+
"shipwreck",
|
917 |
+
"sailboat",
|
918 |
+
"yurt",
|
919 |
+
"website",
|
920 |
+
"comic book",
|
921 |
+
"crossword",
|
922 |
+
"traffic or street sign",
|
923 |
+
"traffic light",
|
924 |
+
"dust jacket",
|
925 |
+
"menu",
|
926 |
+
"plate",
|
927 |
+
"guacamole",
|
928 |
+
"consomme",
|
929 |
+
"hot pot",
|
930 |
+
"trifle",
|
931 |
+
"ice cream",
|
932 |
+
"popsicle",
|
933 |
+
"baguette",
|
934 |
+
"bagel",
|
935 |
+
"pretzel",
|
936 |
+
"cheeseburger",
|
937 |
+
"hot dog",
|
938 |
+
"mashed potatoes",
|
939 |
+
"cabbage",
|
940 |
+
"broccoli",
|
941 |
+
"cauliflower",
|
942 |
+
"zucchini",
|
943 |
+
"spaghetti squash",
|
944 |
+
"acorn squash",
|
945 |
+
"butternut squash",
|
946 |
+
"cucumber",
|
947 |
+
"artichoke",
|
948 |
+
"bell pepper",
|
949 |
+
"cardoon",
|
950 |
+
"mushroom",
|
951 |
+
"Granny Smith apple",
|
952 |
+
"strawberry",
|
953 |
+
"orange",
|
954 |
+
"lemon",
|
955 |
+
"fig",
|
956 |
+
"pineapple",
|
957 |
+
"banana",
|
958 |
+
"jackfruit",
|
959 |
+
"cherimoya (custard apple)",
|
960 |
+
"pomegranate",
|
961 |
+
"hay",
|
962 |
+
"carbonara",
|
963 |
+
"chocolate syrup",
|
964 |
+
"dough",
|
965 |
+
"meatloaf",
|
966 |
+
"pizza",
|
967 |
+
"pot pie",
|
968 |
+
"burrito",
|
969 |
+
"red wine",
|
970 |
+
"espresso",
|
971 |
+
"tea cup",
|
972 |
+
"eggnog",
|
973 |
+
"mountain",
|
974 |
+
"bubble",
|
975 |
+
"cliff",
|
976 |
+
"coral reef",
|
977 |
+
"geyser",
|
978 |
+
"lakeshore",
|
979 |
+
"promontory",
|
980 |
+
"sandbar",
|
981 |
+
"beach",
|
982 |
+
"valley",
|
983 |
+
"volcano",
|
984 |
+
"baseball player",
|
985 |
+
"bridegroom",
|
986 |
+
"scuba diver",
|
987 |
+
"rapeseed",
|
988 |
+
"daisy",
|
989 |
+
"yellow lady's slipper",
|
990 |
+
"corn",
|
991 |
+
"acorn",
|
992 |
+
"rose hip",
|
993 |
+
"horse chestnut seed",
|
994 |
+
"coral fungus",
|
995 |
+
"agaric",
|
996 |
+
"gyromitra",
|
997 |
+
"stinkhorn mushroom",
|
998 |
+
"earth star fungus",
|
999 |
+
"hen of the woods mushroom",
|
1000 |
+
"bolete",
|
1001 |
+
"corn cob",
|
1002 |
+
"toilet paper",
|
1003 |
+
]
|
1004 |
+
# Maps numeric class ids to labels
|
1005 |
+
IMAGENET_1K_CLASS_ID_TO_LABEL = dict(
|
1006 |
+
zip(range(len(openai_imagenet_classnames)), openai_imagenet_classnames)
|
1007 |
+
)
|
multimodal/build/lib/open_flamingo/eval/ok_vqa_utils.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Those are manual mapping that are not caught by our stemming rules or would
|
2 |
+
# would be done incorrectly by our automatic stemming rule. In details,
|
3 |
+
# the keys of the _MANUAL_MATCHES dict contains the original word and the value
|
4 |
+
# contains the transformation of the word expected by the OKVQA stemming rule.
|
5 |
+
# These manual rules were found by checking the `raw_answers` and the `answers`
|
6 |
+
# fields of the released OKVQA dataset and checking all things that were not
|
7 |
+
# properly mapped by our automatic rules. In particular some of the mapping
|
8 |
+
# are sometimes constant, e.g. christmas -> christmas which was incorrectly
|
9 |
+
# singularized by our inflection.singularize.
|
10 |
+
import re
|
11 |
+
import nltk
|
12 |
+
from nltk.corpus.reader import VERB
|
13 |
+
import inflection
|
14 |
+
|
15 |
+
_MANUAL_MATCHES = {
|
16 |
+
"police": "police",
|
17 |
+
"las": "las",
|
18 |
+
"vegas": "vegas",
|
19 |
+
"yes": "yes",
|
20 |
+
"jeans": "jean",
|
21 |
+
"hell's": "hell",
|
22 |
+
"domino's": "domino",
|
23 |
+
"morning": "morn",
|
24 |
+
"clothes": "cloth",
|
25 |
+
"are": "are",
|
26 |
+
"riding": "ride",
|
27 |
+
"leaves": "leaf",
|
28 |
+
"dangerous": "danger",
|
29 |
+
"clothing": "cloth",
|
30 |
+
"texting": "text",
|
31 |
+
"kiting": "kite",
|
32 |
+
"firefighters": "firefight",
|
33 |
+
"ties": "tie",
|
34 |
+
"married": "married",
|
35 |
+
"teething": "teeth",
|
36 |
+
"gloves": "glove",
|
37 |
+
"tennis": "tennis",
|
38 |
+
"dining": "dine",
|
39 |
+
"directions": "direct",
|
40 |
+
"waves": "wave",
|
41 |
+
"christmas": "christmas",
|
42 |
+
"drives": "drive",
|
43 |
+
"pudding": "pud",
|
44 |
+
"coding": "code",
|
45 |
+
"plating": "plate",
|
46 |
+
"quantas": "quanta",
|
47 |
+
"hornes": "horn",
|
48 |
+
"graves": "grave",
|
49 |
+
"mating": "mate",
|
50 |
+
"paned": "pane",
|
51 |
+
"alertness": "alert",
|
52 |
+
"sunbathing": "sunbath",
|
53 |
+
"tenning": "ten",
|
54 |
+
"wetness": "wet",
|
55 |
+
"urinating": "urine",
|
56 |
+
"sickness": "sick",
|
57 |
+
"braves": "brave",
|
58 |
+
"firefighting": "firefight",
|
59 |
+
"lenses": "lens",
|
60 |
+
"reflections": "reflect",
|
61 |
+
"backpackers": "backpack",
|
62 |
+
"eatting": "eat",
|
63 |
+
"designers": "design",
|
64 |
+
"curiousity": "curious",
|
65 |
+
"playfulness": "play",
|
66 |
+
"blindness": "blind",
|
67 |
+
"hawke": "hawk",
|
68 |
+
"tomatoe": "tomato",
|
69 |
+
"rodeoing": "rodeo",
|
70 |
+
"brightness": "bright",
|
71 |
+
"circuses": "circus",
|
72 |
+
"skateboarders": "skateboard",
|
73 |
+
"staring": "stare",
|
74 |
+
"electronics": "electron",
|
75 |
+
"electicity": "elect",
|
76 |
+
"mountainous": "mountain",
|
77 |
+
"socializing": "social",
|
78 |
+
"hamburgers": "hamburg",
|
79 |
+
"caves": "cave",
|
80 |
+
"transitions": "transit",
|
81 |
+
"wading": "wade",
|
82 |
+
"creame": "cream",
|
83 |
+
"toileting": "toilet",
|
84 |
+
"sautee": "saute",
|
85 |
+
"buildings": "build",
|
86 |
+
"belongings": "belong",
|
87 |
+
"stockings": "stock",
|
88 |
+
"walle": "wall",
|
89 |
+
"cumulis": "cumuli",
|
90 |
+
"travelers": "travel",
|
91 |
+
"conducter": "conduct",
|
92 |
+
"browsing": "brows",
|
93 |
+
"pooping": "poop",
|
94 |
+
"haircutting": "haircut",
|
95 |
+
"toppings": "top",
|
96 |
+
"hearding": "heard",
|
97 |
+
"sunblocker": "sunblock",
|
98 |
+
"bases": "base",
|
99 |
+
"markings": "mark",
|
100 |
+
"mopeds": "mope",
|
101 |
+
"kindergartener": "kindergarten",
|
102 |
+
"pies": "pie",
|
103 |
+
"scrapbooking": "scrapbook",
|
104 |
+
"couponing": "coupon",
|
105 |
+
"meetings": "meet",
|
106 |
+
"elevators": "elev",
|
107 |
+
"lowes": "low",
|
108 |
+
"men's": "men",
|
109 |
+
"childrens": "children",
|
110 |
+
"shelves": "shelve",
|
111 |
+
"paintings": "paint",
|
112 |
+
"raines": "rain",
|
113 |
+
"paring": "pare",
|
114 |
+
"expressions": "express",
|
115 |
+
"routes": "rout",
|
116 |
+
"pease": "peas",
|
117 |
+
"vastness": "vast",
|
118 |
+
"awning": "awn",
|
119 |
+
"boy's": "boy",
|
120 |
+
"drunkenness": "drunken",
|
121 |
+
"teasing": "teas",
|
122 |
+
"conferences": "confer",
|
123 |
+
"ripeness": "ripe",
|
124 |
+
"suspenders": "suspend",
|
125 |
+
"earnings": "earn",
|
126 |
+
"reporters": "report",
|
127 |
+
"kid's": "kid",
|
128 |
+
"containers": "contain",
|
129 |
+
"corgie": "corgi",
|
130 |
+
"porche": "porch",
|
131 |
+
"microwaves": "microwave",
|
132 |
+
"batter's": "batter",
|
133 |
+
"sadness": "sad",
|
134 |
+
"apartments": "apart",
|
135 |
+
"oxygenize": "oxygen",
|
136 |
+
"striping": "stripe",
|
137 |
+
"purring": "pure",
|
138 |
+
"professionals": "profession",
|
139 |
+
"piping": "pipe",
|
140 |
+
"farmer's": "farmer",
|
141 |
+
"potatoe": "potato",
|
142 |
+
"emirates": "emir",
|
143 |
+
"womens": "women",
|
144 |
+
"veteran's": "veteran",
|
145 |
+
"wilderness": "wilder",
|
146 |
+
"propellers": "propel",
|
147 |
+
"alpes": "alp",
|
148 |
+
"charioteering": "chariot",
|
149 |
+
"swining": "swine",
|
150 |
+
"illness": "ill",
|
151 |
+
"crepte": "crept",
|
152 |
+
"adhesives": "adhesive",
|
153 |
+
"regent's": "regent",
|
154 |
+
"decorations": "decor",
|
155 |
+
"rabbies": "rabbi",
|
156 |
+
"overseas": "oversea",
|
157 |
+
"travellers": "travel",
|
158 |
+
"casings": "case",
|
159 |
+
"smugness": "smug",
|
160 |
+
"doves": "dove",
|
161 |
+
"nationals": "nation",
|
162 |
+
"mustange": "mustang",
|
163 |
+
"ringe": "ring",
|
164 |
+
"gondoliere": "gondolier",
|
165 |
+
"vacationing": "vacate",
|
166 |
+
"reminders": "remind",
|
167 |
+
"baldness": "bald",
|
168 |
+
"settings": "set",
|
169 |
+
"glaced": "glace",
|
170 |
+
"coniferous": "conifer",
|
171 |
+
"revelations": "revel",
|
172 |
+
"personals": "person",
|
173 |
+
"daughter's": "daughter",
|
174 |
+
"badness": "bad",
|
175 |
+
"projections": "project",
|
176 |
+
"polarizing": "polar",
|
177 |
+
"vandalizers": "vandal",
|
178 |
+
"minerals": "miner",
|
179 |
+
"protesters": "protest",
|
180 |
+
"controllers": "control",
|
181 |
+
"weddings": "wed",
|
182 |
+
"sometimes": "sometime",
|
183 |
+
"earing": "ear",
|
184 |
+
}
|
185 |
+
|
186 |
+
|
187 |
+
class OKVQAStemmer:
|
188 |
+
"""Stemmer to match OKVQA v1.1 procedure."""
|
189 |
+
|
190 |
+
def __init__(self):
|
191 |
+
self._wordnet_lemmatizer = nltk.stem.WordNetLemmatizer()
|
192 |
+
|
193 |
+
def stem(self, input_string):
|
194 |
+
"""Apply stemming."""
|
195 |
+
word_and_pos = nltk.pos_tag(nltk.tokenize.word_tokenize(input_string))
|
196 |
+
stemmed_words = []
|
197 |
+
for w, p in word_and_pos:
|
198 |
+
if w in _MANUAL_MATCHES:
|
199 |
+
w = _MANUAL_MATCHES[w]
|
200 |
+
elif w.endswith("ing"):
|
201 |
+
w = self._wordnet_lemmatizer.lemmatize(w, VERB)
|
202 |
+
elif p.startswith("NNS") or p.startswith("NNPS"):
|
203 |
+
w = inflection.singularize(w)
|
204 |
+
stemmed_words.append(w)
|
205 |
+
return " ".join(stemmed_words)
|
206 |
+
|
207 |
+
|
208 |
+
stemmer = OKVQAStemmer()
|
209 |
+
|
210 |
+
|
211 |
+
def postprocess_ok_vqa_generation(prediction) -> str:
|
212 |
+
prediction_stem = stemmer.stem(prediction)
|
213 |
+
return prediction_stem
|
multimodal/build/lib/open_flamingo/eval/task/__init__.py
ADDED
File without changes
|
multimodal/build/lib/open_flamingo/eval/task/caption.py
ADDED
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from lavis.datasets.builders import load_dataset
|
2 |
+
import torch
|
3 |
+
import more_itertools
|
4 |
+
from tqdm import tqdm
|
5 |
+
from coco_metric import compute_cider, postprocess_captioning_generation
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
import os
|
9 |
+
from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
class VisualLogitsProcessor(LogitsProcessor):
|
13 |
+
def __init__(self, tokenizer):
|
14 |
+
super().__init__()
|
15 |
+
self.tokenizer = tokenizer
|
16 |
+
self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
|
17 |
+
self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
18 |
+
self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
19 |
+
self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
20 |
+
self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
21 |
+
self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
|
22 |
+
self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
23 |
+
self.topk = 2
|
24 |
+
|
25 |
+
def __call__(self, input_ids, scores):
|
26 |
+
# print("decoding===>", self.tokenizer.decode(scores.sort(descending=True).indices.tolist()[0][:self.topk]))
|
27 |
+
# import pdb; pdb.set_trace()
|
28 |
+
if self.object_token_id in scores.sort(descending=True).indices.tolist()[0][1:self.topk] and self.eos_token_id not in scores.sort(descending=True).indices.tolist()[0][:self.topk] and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum():
|
29 |
+
scores[0, self.object_token_id] = 1000
|
30 |
+
if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
|
31 |
+
if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
|
32 |
+
# print("generate a previsual token next")
|
33 |
+
scores[0, self.previsual_token_id] = 1000
|
34 |
+
elif input_ids[0, -1] == self.previsual_token_id or input_ids[0, -1] == self.visual_token_id:
|
35 |
+
# print("stop to run bbox generation for " + "previsual" if input_ids[0, -1] == self.previsual_token_id else "visual")
|
36 |
+
scores[0, self.eos_token_id] = 1000
|
37 |
+
elif input_ids[0, -1] == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
|
38 |
+
# print("generate a visual token next")
|
39 |
+
scores[0, self.visual_token_id] = 1000
|
40 |
+
return scores
|
41 |
+
|
42 |
+
|
43 |
+
def prepare_batch_images(batch, image_processor):
|
44 |
+
batch_images = None
|
45 |
+
for b in batch:
|
46 |
+
b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
47 |
+
if batch_images is None:
|
48 |
+
batch_images = b_image
|
49 |
+
else:
|
50 |
+
batch_images = torch.cat([batch_images, b_image], dim=0)
|
51 |
+
return batch_images
|
52 |
+
|
53 |
+
|
54 |
+
def captioner(
|
55 |
+
model,tokenizer,image_ori,batch_images,input_ids,attention_mask,image_start_index_list,image_nums,added_bbox_list,debug=False):
|
56 |
+
"""Evaluate a model on COCO dataset.
|
57 |
+
Returns:
|
58 |
+
float: CIDEr score
|
59 |
+
|
60 |
+
"""
|
61 |
+
visual_logits_processor = VisualLogitsProcessor(tokenizer)
|
62 |
+
model.eval()
|
63 |
+
# model.eval().cuda()
|
64 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
65 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
66 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
67 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
68 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
69 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
70 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
71 |
+
box_token = "<|#box#|>"
|
72 |
+
prebox_token = "<|#prebox#|>"
|
73 |
+
endofobject_token = "<|#endofobject#|>"
|
74 |
+
object_token = "<|#object#|>"
|
75 |
+
ori_prompt_length = len(input_ids[0])
|
76 |
+
have_prebox = False
|
77 |
+
out_image = None
|
78 |
+
while True:
|
79 |
+
batch_images = batch_images
|
80 |
+
input_ids = input_ids
|
81 |
+
attention_mask = attention_mask
|
82 |
+
image_start_index_list = image_start_index_list
|
83 |
+
image_nums = image_nums
|
84 |
+
if debug:
|
85 |
+
print("input--->",tokenizer.decode(input_ids[0]))
|
86 |
+
p1 = MinNewTokensLengthLogitsProcessor(
|
87 |
+
prompt_length_to_skip=input_ids.shape[-1],
|
88 |
+
min_new_tokens=5,
|
89 |
+
eos_token_id=bos_token_id,
|
90 |
+
)
|
91 |
+
with torch.inference_mode():
|
92 |
+
outputs = model.generate(
|
93 |
+
batch_images,
|
94 |
+
input_ids,
|
95 |
+
attention_mask=attention_mask,
|
96 |
+
max_new_tokens=20,
|
97 |
+
# min_new_tokens=8,
|
98 |
+
num_beams=1,
|
99 |
+
# length_penalty=0,
|
100 |
+
image_start_index_list=image_start_index_list,
|
101 |
+
image_nums=image_nums,
|
102 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
103 |
+
logits_processor_list=[p1, visual_logits_processor],
|
104 |
+
)
|
105 |
+
if debug:
|
106 |
+
print("outputs--->",tokenizer.decode(outputs[0]))
|
107 |
+
if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
|
108 |
+
prompt = tokenizer.decode(outputs.clone()[0])
|
109 |
+
is_visual = (outputs[0, -2] == visual_token_id)
|
110 |
+
batch_text = tokenizer.batch_decode(outputs[:, :-1])
|
111 |
+
encodings = tokenizer(
|
112 |
+
batch_text,
|
113 |
+
padding="longest",
|
114 |
+
truncation=True,
|
115 |
+
return_tensors="pt",
|
116 |
+
max_length=2000,
|
117 |
+
)
|
118 |
+
input_ids = encodings["input_ids"]
|
119 |
+
attention_mask = encodings["attention_mask"]
|
120 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
121 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
122 |
+
image_nums = [1] * len(input_ids)
|
123 |
+
if debug:
|
124 |
+
print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
|
125 |
+
with torch.no_grad():
|
126 |
+
outputs = model(
|
127 |
+
vision_x=batch_images,
|
128 |
+
lang_x=input_ids,
|
129 |
+
attention_mask=attention_mask,
|
130 |
+
image_nums=image_nums,
|
131 |
+
image_start_index_list=image_start_index_list,
|
132 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
133 |
+
add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
|
134 |
+
)
|
135 |
+
boxes = outputs["boxes"]
|
136 |
+
scores = outputs["scores"]
|
137 |
+
# if not model.valid:
|
138 |
+
# import pdb; pdb.set_trace()
|
139 |
+
if boxes is not None:
|
140 |
+
if is_visual:
|
141 |
+
if have_prebox:
|
142 |
+
added_bbox_list.pop()
|
143 |
+
prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
|
144 |
+
have_prebox = False
|
145 |
+
if debug:
|
146 |
+
print("find previsual and remove it--->", prompt)
|
147 |
+
first_box = boxes[scores.argmax()]
|
148 |
+
added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
|
149 |
+
prompt = prompt[:-len(tokenizer.eos_token)]
|
150 |
+
prompt += box_token + endofobject_token
|
151 |
+
if debug:
|
152 |
+
print("after inserting visual---->", prompt)
|
153 |
+
else:
|
154 |
+
import numpy as np
|
155 |
+
import cv2
|
156 |
+
open_cv_image = np.array(image_ori)
|
157 |
+
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
158 |
+
for i, pre_box in enumerate(boxes):
|
159 |
+
open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i+1)
|
160 |
+
out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
|
161 |
+
# exit()
|
162 |
+
pre_box = boxes[scores.argmax()]
|
163 |
+
added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
|
164 |
+
prompt = prompt[:-len(tokenizer.eos_token)]
|
165 |
+
prompt += prebox_token + object_token
|
166 |
+
have_prebox = True
|
167 |
+
if debug:
|
168 |
+
print("after inserting previsual---->", prompt)
|
169 |
+
else:
|
170 |
+
if debug:
|
171 |
+
import pdb;pdb.set_trace()
|
172 |
+
prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
|
173 |
+
else:
|
174 |
+
break
|
175 |
+
outputs = outputs[:, ori_prompt_length:]
|
176 |
+
outputs = postprocess_captioning_generation(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]).replace('"', "")
|
177 |
+
# new_predictions = [
|
178 |
+
# postprocess_captioning_generation(out).replace('"', "")
|
179 |
+
# for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
180 |
+
# ]
|
181 |
+
# import pdb; pdb.set_trace()
|
182 |
+
return outputs, out_image
|
183 |
+
|
184 |
+
|
185 |
+
def evaluate_coco_flickr(
|
186 |
+
model,
|
187 |
+
tokenizer,
|
188 |
+
image_processor,
|
189 |
+
batch_size,
|
190 |
+
is_flickr=False,
|
191 |
+
vis_embed_size=None,
|
192 |
+
rank=0,
|
193 |
+
world_size=1,
|
194 |
+
id=0,
|
195 |
+
debug=False,
|
196 |
+
):
|
197 |
+
"""Evaluate a model on COCO dataset.
|
198 |
+
Returns:
|
199 |
+
float: CIDEr score
|
200 |
+
|
201 |
+
"""
|
202 |
+
visual_logits_processor = VisualLogitsProcessor(tokenizer)
|
203 |
+
coco_dataset = load_dataset("coco_caption")
|
204 |
+
eval_dataset = coco_dataset["test"]
|
205 |
+
model.eval().cuda()
|
206 |
+
predictions = dict()
|
207 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
208 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
209 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
210 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
211 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
212 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
213 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
214 |
+
box_token = "<|#box#|>"
|
215 |
+
prebox_token = "<|#prebox#|>"
|
216 |
+
endofobject_token = "<|#endofobject#|>"
|
217 |
+
object_token = "<|#object#|>"
|
218 |
+
cnt = 0
|
219 |
+
if world_size > 1:
|
220 |
+
torch.distributed.barrier()
|
221 |
+
desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
|
222 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
223 |
+
tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
|
224 |
+
)):
|
225 |
+
if ii % world_size != rank:
|
226 |
+
continue
|
227 |
+
cnt += len(batch)
|
228 |
+
batch[0]["image"] = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/images/img3.jpg").resize((224, 224))
|
229 |
+
batch_images = prepare_batch_images(
|
230 |
+
batch=batch,
|
231 |
+
image_processor=image_processor,
|
232 |
+
).cuda()
|
233 |
+
prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
234 |
+
added_bbox_list = []
|
235 |
+
batch_text = [prompt for _ in batch]
|
236 |
+
encodings = tokenizer(
|
237 |
+
batch_text,
|
238 |
+
padding="longest",
|
239 |
+
truncation=True,
|
240 |
+
return_tensors="pt",
|
241 |
+
max_length=2000,
|
242 |
+
)
|
243 |
+
ori_prompt_length = len(encodings["input_ids"][0])
|
244 |
+
have_prebox = False
|
245 |
+
while True:
|
246 |
+
batch_text = [prompt for _ in batch]
|
247 |
+
encodings = tokenizer(
|
248 |
+
batch_text,
|
249 |
+
padding="longest",
|
250 |
+
truncation=True,
|
251 |
+
return_tensors="pt",
|
252 |
+
max_length=2000,
|
253 |
+
)
|
254 |
+
input_ids = encodings["input_ids"].cuda()
|
255 |
+
attention_mask = encodings["attention_mask"].cuda()
|
256 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
257 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
258 |
+
image_nums = [1] * len(input_ids)
|
259 |
+
if debug:
|
260 |
+
print("input--->",tokenizer.decode(input_ids[0]))
|
261 |
+
p1 = MinNewTokensLengthLogitsProcessor(
|
262 |
+
prompt_length_to_skip=input_ids.shape[-1],
|
263 |
+
min_new_tokens=5,
|
264 |
+
eos_token_id=bos_token_id,
|
265 |
+
)
|
266 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
267 |
+
outputs = model.generate(
|
268 |
+
batch_images,
|
269 |
+
input_ids,
|
270 |
+
attention_mask=attention_mask,
|
271 |
+
max_new_tokens=20,
|
272 |
+
# min_new_tokens=8,
|
273 |
+
num_beams=1,
|
274 |
+
# length_penalty=0,
|
275 |
+
image_start_index_list=image_start_index_list,
|
276 |
+
image_nums=image_nums,
|
277 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
278 |
+
logits_processor_list=[p1, visual_logits_processor],
|
279 |
+
)
|
280 |
+
if debug:
|
281 |
+
print("outputs--->",tokenizer.decode(outputs[0]))
|
282 |
+
if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
|
283 |
+
prompt = tokenizer.decode(outputs.clone()[0])
|
284 |
+
is_visual = (outputs[0, -2] == visual_token_id)
|
285 |
+
batch_text = tokenizer.batch_decode(outputs[:, :-1])
|
286 |
+
encodings = tokenizer(
|
287 |
+
batch_text,
|
288 |
+
padding="longest",
|
289 |
+
truncation=True,
|
290 |
+
return_tensors="pt",
|
291 |
+
max_length=2000,
|
292 |
+
)
|
293 |
+
input_ids = encodings["input_ids"].cuda()
|
294 |
+
attention_mask = encodings["attention_mask"].cuda()
|
295 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
296 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
297 |
+
image_nums = [1] * len(input_ids)
|
298 |
+
if debug:
|
299 |
+
print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
|
300 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
301 |
+
outputs = model(
|
302 |
+
vision_x=batch_images,
|
303 |
+
lang_x=input_ids,
|
304 |
+
attention_mask=attention_mask,
|
305 |
+
image_nums=image_nums,
|
306 |
+
image_start_index_list=image_start_index_list,
|
307 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
308 |
+
add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
|
309 |
+
)
|
310 |
+
boxes = outputs["boxes"]
|
311 |
+
scores = outputs["scores"]
|
312 |
+
# if not model.valid:
|
313 |
+
# import pdb; pdb.set_trace()
|
314 |
+
if boxes is not None:
|
315 |
+
if is_visual:
|
316 |
+
if have_prebox:
|
317 |
+
added_bbox_list.pop()
|
318 |
+
prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
|
319 |
+
have_prebox = False
|
320 |
+
if debug:
|
321 |
+
print("find previsual and remove it--->", prompt)
|
322 |
+
first_box = boxes[scores.argmax()]
|
323 |
+
added_bbox_list += [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
|
324 |
+
prompt = prompt[:-len(tokenizer.eos_token)]
|
325 |
+
prompt += box_token + endofobject_token
|
326 |
+
if debug:
|
327 |
+
print("after inserting visual---->", prompt)
|
328 |
+
else:
|
329 |
+
import numpy as np
|
330 |
+
import cv2
|
331 |
+
open_cv_image = np.array(batch[0]["image"])
|
332 |
+
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
333 |
+
for i, pre_box in enumerate(boxes):
|
334 |
+
open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i+1)
|
335 |
+
cv2.imwrite("Atest.png", open_cv_image)
|
336 |
+
exit()
|
337 |
+
pre_box = boxes[scores.argmax()]
|
338 |
+
added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
|
339 |
+
prompt = prompt[:-len(tokenizer.eos_token)]
|
340 |
+
prompt += prebox_token + object_token
|
341 |
+
have_prebox = True
|
342 |
+
if debug:
|
343 |
+
print("after inserting previsual---->", prompt)
|
344 |
+
else:
|
345 |
+
import pdb;pdb.set_trace()
|
346 |
+
prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
|
347 |
+
else:
|
348 |
+
break
|
349 |
+
outputs = outputs[:, ori_prompt_length:]
|
350 |
+
new_predictions = [
|
351 |
+
postprocess_captioning_generation(out).replace('"', "")
|
352 |
+
for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
353 |
+
]
|
354 |
+
# import pdb; pdb.set_trace()
|
355 |
+
if rank == 0:
|
356 |
+
tqdm.write(new_predictions[0])
|
357 |
+
for i, sample in enumerate(batch):
|
358 |
+
predictions[int(sample["image_id"])] = {
|
359 |
+
"caption": new_predictions[i],
|
360 |
+
}
|
361 |
+
print(new_predictions)
|
362 |
+
exit()
|
363 |
+
results_path = (
|
364 |
+
f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
|
365 |
+
if is_flickr
|
366 |
+
else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
|
367 |
+
)
|
368 |
+
with open(results_path, "w") as f:
|
369 |
+
f.write(
|
370 |
+
json.dumps(
|
371 |
+
[
|
372 |
+
{"image_id": k, "caption": predictions[k]["caption"]}
|
373 |
+
for k in predictions
|
374 |
+
],
|
375 |
+
indent=2,
|
376 |
+
)
|
377 |
+
)
|
378 |
+
print("save to", results_path)
|
379 |
+
del predictions
|
380 |
+
time.sleep(10)
|
381 |
+
if world_size > 1:
|
382 |
+
torch.distributed.barrier()
|
383 |
+
if rank == 0:
|
384 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
385 |
+
predictions = []
|
386 |
+
for rank_i in range(world_size):
|
387 |
+
part_results_path = (
|
388 |
+
f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
389 |
+
if is_flickr
|
390 |
+
else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
391 |
+
)
|
392 |
+
print("load", part_results_path)
|
393 |
+
predictions.extend(json.load(open(part_results_path)))
|
394 |
+
os.remove(part_results_path)
|
395 |
+
print("num:", len(predictions))
|
396 |
+
results_path = (
|
397 |
+
f"flickrresults_{lang_encoder_name}.json"
|
398 |
+
if is_flickr
|
399 |
+
else f"cocoresults_{lang_encoder_name}.json"
|
400 |
+
)
|
401 |
+
json.dump(predictions, open(results_path, "w"), indent=2)
|
402 |
+
|
403 |
+
metrics = compute_cider(
|
404 |
+
result_path=results_path,
|
405 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
|
406 |
+
)
|
407 |
+
metrics["CIDEr"] *= 100
|
408 |
+
os.makedirs("eval_results", exist_ok=True)
|
409 |
+
acc = metrics["CIDEr"]
|
410 |
+
with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
411 |
+
f.write(json.dumps(predictions, indent=2))
|
412 |
+
|
413 |
+
# delete the temporary file
|
414 |
+
os.remove(results_path)
|
415 |
+
else:
|
416 |
+
metrics = {}
|
417 |
+
metrics["CIDEr"] = 0.0
|
418 |
+
|
419 |
+
return metrics["CIDEr"]
|
multimodal/build/lib/open_flamingo/eval/task/caption_chat.py
ADDED
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import more_itertools
|
4 |
+
from tqdm import tqdm
|
5 |
+
import json
|
6 |
+
import time
|
7 |
+
import os
|
8 |
+
from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
class VisualLogitsProcessor(LogitsProcessor):
|
12 |
+
def __init__(self, tokenizer):
|
13 |
+
super().__init__()
|
14 |
+
self.tokenizer = tokenizer
|
15 |
+
self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
|
16 |
+
self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
17 |
+
self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
18 |
+
self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
19 |
+
self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
20 |
+
self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
|
21 |
+
self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
22 |
+
self.topk = 2
|
23 |
+
|
24 |
+
def __call__(self, input_ids, scores):
|
25 |
+
# print("decoding===>", self.tokenizer.decode(scores.sort(descending=True).indices.tolist()[0][:self.topk]))
|
26 |
+
# import pdb; pdb.set_trace()
|
27 |
+
if self.object_token_id in scores.sort(descending=True).indices.tolist()[0][1:self.topk] and self.eos_token_id not in scores.sort(descending=True).indices.tolist()[0][:self.topk] and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum():
|
28 |
+
scores[0, self.object_token_id] = 1000
|
29 |
+
if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
|
30 |
+
if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
|
31 |
+
# print("generate a previsual token next")
|
32 |
+
scores[0, self.previsual_token_id] = 1000
|
33 |
+
elif input_ids[0, -1] == self.previsual_token_id or input_ids[0, -1] == self.visual_token_id:
|
34 |
+
# print("stop to run bbox generation for " + "previsual" if input_ids[0, -1] == self.previsual_token_id else "visual")
|
35 |
+
scores[0, self.eos_token_id] = 1000
|
36 |
+
elif input_ids[0, -1] == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
|
37 |
+
# print("generate a visual token next")
|
38 |
+
scores[0, self.visual_token_id] = 1000
|
39 |
+
return scores
|
40 |
+
|
41 |
+
|
42 |
+
def prepare_batch_images(batch, image_processor):
|
43 |
+
batch_images = None
|
44 |
+
for b in batch:
|
45 |
+
b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
46 |
+
if batch_images is None:
|
47 |
+
batch_images = b_image
|
48 |
+
else:
|
49 |
+
batch_images = torch.cat([batch_images, b_image], dim=0)
|
50 |
+
return batch_images
|
51 |
+
|
52 |
+
|
53 |
+
def captioner(
|
54 |
+
model,tokenizer,image_ori,batch_images,input_ids,attention_mask,image_start_index_list,image_nums,added_bbox_list,debug=False):
|
55 |
+
"""Evaluate a model on COCO dataset.
|
56 |
+
Returns:
|
57 |
+
float: CIDEr score
|
58 |
+
|
59 |
+
"""
|
60 |
+
visual_logits_processor = VisualLogitsProcessor(tokenizer)
|
61 |
+
model.eval()
|
62 |
+
# model.eval().cuda()
|
63 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
64 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
65 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
66 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
67 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
68 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
69 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
70 |
+
box_token = "<|#box#|>"
|
71 |
+
prebox_token = "<|#prebox#|>"
|
72 |
+
endofobject_token = "<|#endofobject#|>"
|
73 |
+
object_token = "<|#object#|>"
|
74 |
+
ori_prompt_length = len(input_ids[0])
|
75 |
+
have_prebox = False
|
76 |
+
while True:
|
77 |
+
batch_images = batch_images
|
78 |
+
input_ids = input_ids
|
79 |
+
attention_mask = attention_mask
|
80 |
+
image_start_index_list = image_start_index_list
|
81 |
+
image_nums = image_nums
|
82 |
+
if debug:
|
83 |
+
print("input--->",tokenizer.decode(input_ids[0]))
|
84 |
+
p1 = MinNewTokensLengthLogitsProcessor(
|
85 |
+
prompt_length_to_skip=input_ids.shape[-1],
|
86 |
+
min_new_tokens=5,
|
87 |
+
eos_token_id=bos_token_id,
|
88 |
+
)
|
89 |
+
with torch.inference_mode():
|
90 |
+
outputs = model.generate(
|
91 |
+
batch_images,
|
92 |
+
input_ids,
|
93 |
+
attention_mask=attention_mask,
|
94 |
+
max_new_tokens=20,
|
95 |
+
# min_new_tokens=8,
|
96 |
+
num_beams=1,
|
97 |
+
# length_penalty=0,
|
98 |
+
image_start_index_list=image_start_index_list,
|
99 |
+
image_nums=image_nums,
|
100 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
101 |
+
logits_processor_list=[p1, visual_logits_processor],
|
102 |
+
)
|
103 |
+
if debug:
|
104 |
+
print("outputs--->",tokenizer.decode(outputs[0]))
|
105 |
+
if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
|
106 |
+
prompt = tokenizer.decode(outputs.clone()[0])
|
107 |
+
is_visual = (outputs[0, -2] == visual_token_id)
|
108 |
+
batch_text = tokenizer.batch_decode(outputs[:, :-1])
|
109 |
+
encodings = tokenizer(
|
110 |
+
batch_text,
|
111 |
+
padding="longest",
|
112 |
+
truncation=True,
|
113 |
+
return_tensors="pt",
|
114 |
+
max_length=2000,
|
115 |
+
)
|
116 |
+
input_ids = encodings["input_ids"]
|
117 |
+
attention_mask = encodings["attention_mask"]
|
118 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
119 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
120 |
+
image_nums = [1] * len(input_ids)
|
121 |
+
if debug:
|
122 |
+
print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
|
123 |
+
with torch.no_grad():
|
124 |
+
outputs = model(
|
125 |
+
vision_x=batch_images,
|
126 |
+
lang_x=input_ids,
|
127 |
+
attention_mask=attention_mask,
|
128 |
+
image_nums=image_nums,
|
129 |
+
image_start_index_list=image_start_index_list,
|
130 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
131 |
+
add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
|
132 |
+
)
|
133 |
+
boxes = outputs["boxes"]
|
134 |
+
scores = outputs["scores"]
|
135 |
+
# if not model.valid:
|
136 |
+
# import pdb; pdb.set_trace()
|
137 |
+
if boxes is not None:
|
138 |
+
if is_visual:
|
139 |
+
if have_prebox:
|
140 |
+
added_bbox_list.pop()
|
141 |
+
prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
|
142 |
+
have_prebox = False
|
143 |
+
if debug:
|
144 |
+
print("find previsual and remove it--->", prompt)
|
145 |
+
first_box = boxes[scores.argmax()]
|
146 |
+
added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
|
147 |
+
prompt = prompt[:-len(tokenizer.eos_token)]
|
148 |
+
prompt += box_token + endofobject_token
|
149 |
+
if debug:
|
150 |
+
print("after inserting visual---->", prompt)
|
151 |
+
else:
|
152 |
+
import numpy as np
|
153 |
+
import cv2
|
154 |
+
open_cv_image = np.array(image_ori)
|
155 |
+
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
156 |
+
for i, pre_box in enumerate(boxes):
|
157 |
+
open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i+1)
|
158 |
+
out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
|
159 |
+
# exit()
|
160 |
+
pre_box = boxes[scores.argmax()]
|
161 |
+
added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
|
162 |
+
prompt = prompt[:-len(tokenizer.eos_token)]
|
163 |
+
prompt += prebox_token + object_token
|
164 |
+
have_prebox = True
|
165 |
+
if debug:
|
166 |
+
print("after inserting previsual---->", prompt)
|
167 |
+
else:
|
168 |
+
if debug:
|
169 |
+
import pdb;pdb.set_trace()
|
170 |
+
prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
|
171 |
+
else:
|
172 |
+
break
|
173 |
+
outputs = outputs[:, ori_prompt_length:]
|
174 |
+
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].replace('"', "")
|
175 |
+
# new_predictions = [
|
176 |
+
# postprocess_captioning_generation(out).replace('"', "")
|
177 |
+
# for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
178 |
+
# ]
|
179 |
+
# import pdb; pdb.set_trace()
|
180 |
+
return outputs, out_image
|
181 |
+
|
182 |
+
|
183 |
+
def evaluate_coco_flickr(
|
184 |
+
model,
|
185 |
+
tokenizer,
|
186 |
+
image_processor,
|
187 |
+
batch_size,
|
188 |
+
is_flickr=False,
|
189 |
+
vis_embed_size=None,
|
190 |
+
rank=0,
|
191 |
+
world_size=1,
|
192 |
+
id=0,
|
193 |
+
debug=False,
|
194 |
+
):
|
195 |
+
"""Evaluate a model on COCO dataset.
|
196 |
+
Returns:
|
197 |
+
float: CIDEr score
|
198 |
+
|
199 |
+
"""
|
200 |
+
visual_logits_processor = VisualLogitsProcessor(tokenizer)
|
201 |
+
coco_dataset = load_dataset("coco_caption")
|
202 |
+
eval_dataset = coco_dataset["test"]
|
203 |
+
model.eval().cuda()
|
204 |
+
predictions = dict()
|
205 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
206 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
207 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
208 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
209 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
210 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
211 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
212 |
+
box_token = "<|#box#|>"
|
213 |
+
prebox_token = "<|#prebox#|>"
|
214 |
+
endofobject_token = "<|#endofobject#|>"
|
215 |
+
object_token = "<|#object#|>"
|
216 |
+
cnt = 0
|
217 |
+
if world_size > 1:
|
218 |
+
torch.distributed.barrier()
|
219 |
+
desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
|
220 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
221 |
+
tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
|
222 |
+
)):
|
223 |
+
if ii % world_size != rank:
|
224 |
+
continue
|
225 |
+
cnt += len(batch)
|
226 |
+
batch[0]["image"] = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/images/img3.jpg").resize((224, 224))
|
227 |
+
batch_images = prepare_batch_images(
|
228 |
+
batch=batch,
|
229 |
+
image_processor=image_processor,
|
230 |
+
).cuda()
|
231 |
+
prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
|
232 |
+
added_bbox_list = []
|
233 |
+
batch_text = [prompt for _ in batch]
|
234 |
+
encodings = tokenizer(
|
235 |
+
batch_text,
|
236 |
+
padding="longest",
|
237 |
+
truncation=True,
|
238 |
+
return_tensors="pt",
|
239 |
+
max_length=2000,
|
240 |
+
)
|
241 |
+
ori_prompt_length = len(encodings["input_ids"][0])
|
242 |
+
have_prebox = False
|
243 |
+
while True:
|
244 |
+
batch_text = [prompt for _ in batch]
|
245 |
+
encodings = tokenizer(
|
246 |
+
batch_text,
|
247 |
+
padding="longest",
|
248 |
+
truncation=True,
|
249 |
+
return_tensors="pt",
|
250 |
+
max_length=2000,
|
251 |
+
)
|
252 |
+
input_ids = encodings["input_ids"].cuda()
|
253 |
+
attention_mask = encodings["attention_mask"].cuda()
|
254 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
255 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
256 |
+
image_nums = [1] * len(input_ids)
|
257 |
+
if debug:
|
258 |
+
print("input--->",tokenizer.decode(input_ids[0]))
|
259 |
+
p1 = MinNewTokensLengthLogitsProcessor(
|
260 |
+
prompt_length_to_skip=input_ids.shape[-1],
|
261 |
+
min_new_tokens=5,
|
262 |
+
eos_token_id=bos_token_id,
|
263 |
+
)
|
264 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
265 |
+
outputs = model.generate(
|
266 |
+
batch_images,
|
267 |
+
input_ids,
|
268 |
+
attention_mask=attention_mask,
|
269 |
+
max_new_tokens=20,
|
270 |
+
# min_new_tokens=8,
|
271 |
+
num_beams=1,
|
272 |
+
# length_penalty=0,
|
273 |
+
image_start_index_list=image_start_index_list,
|
274 |
+
image_nums=image_nums,
|
275 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
276 |
+
logits_processor_list=[p1, visual_logits_processor],
|
277 |
+
)
|
278 |
+
if debug:
|
279 |
+
print("outputs--->",tokenizer.decode(outputs[0]))
|
280 |
+
if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
|
281 |
+
prompt = tokenizer.decode(outputs.clone()[0])
|
282 |
+
is_visual = (outputs[0, -2] == visual_token_id)
|
283 |
+
batch_text = tokenizer.batch_decode(outputs[:, :-1])
|
284 |
+
encodings = tokenizer(
|
285 |
+
batch_text,
|
286 |
+
padding="longest",
|
287 |
+
truncation=True,
|
288 |
+
return_tensors="pt",
|
289 |
+
max_length=2000,
|
290 |
+
)
|
291 |
+
input_ids = encodings["input_ids"].cuda()
|
292 |
+
attention_mask = encodings["attention_mask"].cuda()
|
293 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
294 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
295 |
+
image_nums = [1] * len(input_ids)
|
296 |
+
if debug:
|
297 |
+
print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
|
298 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
299 |
+
outputs = model(
|
300 |
+
vision_x=batch_images,
|
301 |
+
lang_x=input_ids,
|
302 |
+
attention_mask=attention_mask,
|
303 |
+
image_nums=image_nums,
|
304 |
+
image_start_index_list=image_start_index_list,
|
305 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
306 |
+
add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
|
307 |
+
)
|
308 |
+
boxes = outputs["boxes"]
|
309 |
+
scores = outputs["scores"]
|
310 |
+
# if not model.valid:
|
311 |
+
# import pdb; pdb.set_trace()
|
312 |
+
if boxes is not None:
|
313 |
+
if is_visual:
|
314 |
+
if have_prebox:
|
315 |
+
added_bbox_list.pop()
|
316 |
+
prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
|
317 |
+
have_prebox = False
|
318 |
+
if debug:
|
319 |
+
print("find previsual and remove it--->", prompt)
|
320 |
+
first_box = boxes[scores.argmax()]
|
321 |
+
added_bbox_list += [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
|
322 |
+
prompt = prompt[:-len(tokenizer.eos_token)]
|
323 |
+
prompt += box_token + endofobject_token
|
324 |
+
if debug:
|
325 |
+
print("after inserting visual---->", prompt)
|
326 |
+
else:
|
327 |
+
import numpy as np
|
328 |
+
import cv2
|
329 |
+
open_cv_image = np.array(batch[0]["image"])
|
330 |
+
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
331 |
+
for i, pre_box in enumerate(boxes):
|
332 |
+
open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i+1)
|
333 |
+
cv2.imwrite("Atest.png", open_cv_image)
|
334 |
+
exit()
|
335 |
+
pre_box = boxes[scores.argmax()]
|
336 |
+
added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
|
337 |
+
prompt = prompt[:-len(tokenizer.eos_token)]
|
338 |
+
prompt += prebox_token + object_token
|
339 |
+
have_prebox = True
|
340 |
+
if debug:
|
341 |
+
print("after inserting previsual---->", prompt)
|
342 |
+
else:
|
343 |
+
import pdb;pdb.set_trace()
|
344 |
+
prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
|
345 |
+
else:
|
346 |
+
break
|
347 |
+
outputs = outputs[:, ori_prompt_length:]
|
348 |
+
new_predictions = [
|
349 |
+
postprocess_captioning_generation(out).replace('"', "")
|
350 |
+
for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
351 |
+
]
|
352 |
+
# import pdb; pdb.set_trace()
|
353 |
+
if rank == 0:
|
354 |
+
tqdm.write(new_predictions[0])
|
355 |
+
for i, sample in enumerate(batch):
|
356 |
+
predictions[int(sample["image_id"])] = {
|
357 |
+
"caption": new_predictions[i],
|
358 |
+
}
|
359 |
+
print(new_predictions)
|
360 |
+
exit()
|
361 |
+
results_path = (
|
362 |
+
f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
|
363 |
+
if is_flickr
|
364 |
+
else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
|
365 |
+
)
|
366 |
+
with open(results_path, "w") as f:
|
367 |
+
f.write(
|
368 |
+
json.dumps(
|
369 |
+
[
|
370 |
+
{"image_id": k, "caption": predictions[k]["caption"]}
|
371 |
+
for k in predictions
|
372 |
+
],
|
373 |
+
indent=2,
|
374 |
+
)
|
375 |
+
)
|
376 |
+
print("save to", results_path)
|
377 |
+
del predictions
|
378 |
+
time.sleep(10)
|
379 |
+
if world_size > 1:
|
380 |
+
torch.distributed.barrier()
|
381 |
+
if rank == 0:
|
382 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
383 |
+
predictions = []
|
384 |
+
for rank_i in range(world_size):
|
385 |
+
part_results_path = (
|
386 |
+
f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
387 |
+
if is_flickr
|
388 |
+
else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
|
389 |
+
)
|
390 |
+
print("load", part_results_path)
|
391 |
+
predictions.extend(json.load(open(part_results_path)))
|
392 |
+
os.remove(part_results_path)
|
393 |
+
print("num:", len(predictions))
|
394 |
+
results_path = (
|
395 |
+
f"flickrresults_{lang_encoder_name}.json"
|
396 |
+
if is_flickr
|
397 |
+
else f"cocoresults_{lang_encoder_name}.json"
|
398 |
+
)
|
399 |
+
json.dump(predictions, open(results_path, "w"), indent=2)
|
400 |
+
|
401 |
+
metrics = compute_cider(
|
402 |
+
result_path=results_path,
|
403 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
|
404 |
+
)
|
405 |
+
metrics["CIDEr"] *= 100
|
406 |
+
os.makedirs("eval_results", exist_ok=True)
|
407 |
+
acc = metrics["CIDEr"]
|
408 |
+
with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
409 |
+
f.write(json.dumps(predictions, indent=2))
|
410 |
+
|
411 |
+
# delete the temporary file
|
412 |
+
os.remove(results_path)
|
413 |
+
else:
|
414 |
+
metrics = {}
|
415 |
+
metrics["CIDEr"] = 0.0
|
416 |
+
|
417 |
+
return metrics["CIDEr"]
|
multimodal/build/lib/open_flamingo/eval/task/cola.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import webdataset as wds
|
3 |
+
from tqdm import tqdm
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
import cv2
|
10 |
+
import random
|
11 |
+
import math
|
12 |
+
from open_flamingo.eval.task.utils import (
|
13 |
+
get_object_from_text,
|
14 |
+
is_correct,
|
15 |
+
_eval_text_image,
|
16 |
+
get_bbox,
|
17 |
+
get_iou,
|
18 |
+
)
|
19 |
+
DATASET = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/COLA/data/COLA_multiobjects_matching_benchmark.json"
|
20 |
+
VG_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/VG_100K"
|
21 |
+
|
22 |
+
def get_score(image, text, model, tokenizer, image_processor, vis_embed_size):
|
23 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
24 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
25 |
+
object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
|
26 |
+
text = text.split("#")
|
27 |
+
obj_A = text[0].strip().split(" ")
|
28 |
+
relation = text[1].strip()
|
29 |
+
obj_B = text[2].strip().split(" ")
|
30 |
+
if "computer mouse" not in text[0].strip():
|
31 |
+
attrAs = obj_A[:-1]
|
32 |
+
nounA = obj_A[-1]
|
33 |
+
else:
|
34 |
+
attrAs = obj_A[:-2]
|
35 |
+
nounA = " ".join(obj_A[-2:])
|
36 |
+
if "computer mouse" not in text[2].strip():
|
37 |
+
attrBs = obj_B[:-1]
|
38 |
+
nounB = obj_B[-1]
|
39 |
+
else:
|
40 |
+
attrBs = obj_B[:-2]
|
41 |
+
nounB = " ".join(obj_B[-2:])
|
42 |
+
# print("="*80)
|
43 |
+
# print(attrAs, nounA)
|
44 |
+
# print(attrBs, nounB)
|
45 |
+
# print(relation)
|
46 |
+
# print("="*80)
|
47 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
48 |
+
|
49 |
+
|
50 |
+
prompt1 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>the {nounA}<|#endofobject#|><|#visual#|>"]
|
51 |
+
boxes, scores = get_bbox(None, batch_images, prompt1, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
|
52 |
+
|
53 |
+
|
54 |
+
# open_cv_image = np.array(image)
|
55 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
56 |
+
# for pre_box in boxes:
|
57 |
+
# open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
|
58 |
+
|
59 |
+
box_ppl = []
|
60 |
+
box_attr_losses = []
|
61 |
+
for box in boxes:
|
62 |
+
losses = []
|
63 |
+
for attrA in attrAs:
|
64 |
+
prompt2 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {attrA} {nounA}"]
|
65 |
+
encodings = tokenizer(
|
66 |
+
prompt2,
|
67 |
+
padding="longest",
|
68 |
+
truncation=True,
|
69 |
+
return_tensors="pt",
|
70 |
+
max_length=512,
|
71 |
+
)
|
72 |
+
input_ids = encodings["input_ids"]
|
73 |
+
attention_mask = encodings["attention_mask"]
|
74 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
75 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
76 |
+
image_nums = [1] * len(input_ids)
|
77 |
+
vision_x = batch_images.cuda()
|
78 |
+
lang_x = input_ids.cuda()
|
79 |
+
attention_mask = attention_mask.cuda()
|
80 |
+
labels = lang_x.clone()
|
81 |
+
start_idx = (labels == object_token_id).nonzero()[-1, -1]
|
82 |
+
labels[0, :start_idx+1] = -100
|
83 |
+
added_bbox_list = [torch.tensor(box / 224.0).cuda().unsqueeze(0)]
|
84 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
85 |
+
outputs = model(
|
86 |
+
vision_x=vision_x,
|
87 |
+
lang_x=lang_x,
|
88 |
+
attention_mask=attention_mask,
|
89 |
+
labels=labels,
|
90 |
+
image_nums=image_nums,
|
91 |
+
image_start_index_list=image_start_index_list,
|
92 |
+
added_bbox_list=added_bbox_list,
|
93 |
+
add_box=added_bbox_list is not None,
|
94 |
+
relations=None,
|
95 |
+
)
|
96 |
+
loss = outputs.loss
|
97 |
+
loss = (loss.sum() / (loss != 0).sum()).item()
|
98 |
+
losses.append(loss)
|
99 |
+
avg_ppl = np.array(losses).mean()
|
100 |
+
box_ppl.append(avg_ppl)
|
101 |
+
box_attr_losses.append(losses)
|
102 |
+
fit_idx = np.array(box_ppl).argmin()
|
103 |
+
fit_box = boxes[fit_idx]
|
104 |
+
fit_attr = attrAs[np.array(box_attr_losses[fit_idx]).argmin()]
|
105 |
+
first_ppl = min(box_ppl)
|
106 |
+
|
107 |
+
# open_cv_image = cv2.rectangle(open_cv_image, fit_box[:2].astype(int), fit_box[2:].astype(int), (255, 0, 0), 2)
|
108 |
+
|
109 |
+
|
110 |
+
prompt3 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>the {fit_attr} {nounA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> is {relation}<|#object#|><|#previsual#|>"]
|
111 |
+
boxes, scores = get_bbox([torch.tensor(fit_box / 224).cuda().unsqueeze(0)], batch_images, prompt3, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
|
112 |
+
# for i, pre_box in enumerate(boxes):
|
113 |
+
# open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 0, 255), i+1)
|
114 |
+
# cv2.imwrite(f"Atest.png", open_cv_image)
|
115 |
+
|
116 |
+
box_ppl = []
|
117 |
+
for box in boxes:
|
118 |
+
losses = []
|
119 |
+
for attrB in attrBs:
|
120 |
+
prompt4 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>the {fit_attr} {nounA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {attrB} {nounB}"]
|
121 |
+
encodings = tokenizer(
|
122 |
+
prompt4,
|
123 |
+
padding="longest",
|
124 |
+
truncation=True,
|
125 |
+
return_tensors="pt",
|
126 |
+
max_length=512,
|
127 |
+
)
|
128 |
+
input_ids = encodings["input_ids"]
|
129 |
+
attention_mask = encodings["attention_mask"]
|
130 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
131 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
132 |
+
image_nums = [1] * len(input_ids)
|
133 |
+
vision_x = batch_images.cuda()
|
134 |
+
lang_x = input_ids.cuda()
|
135 |
+
attention_mask = attention_mask.cuda()
|
136 |
+
labels = lang_x.clone()
|
137 |
+
start_idx = (labels == object_token_id).nonzero()[-1, -1]
|
138 |
+
labels[0, :start_idx+1] = -100
|
139 |
+
added_bbox_list = [torch.tensor(fit_box / 224.0).cuda().unsqueeze(0), torch.tensor(box / 224.0).cuda().unsqueeze(0)]
|
140 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
141 |
+
outputs = model(
|
142 |
+
vision_x=vision_x,
|
143 |
+
lang_x=lang_x,
|
144 |
+
attention_mask=attention_mask,
|
145 |
+
labels=labels,
|
146 |
+
image_nums=image_nums,
|
147 |
+
image_start_index_list=image_start_index_list,
|
148 |
+
added_bbox_list=added_bbox_list,
|
149 |
+
add_box=added_bbox_list is not None,
|
150 |
+
relations=None,
|
151 |
+
)
|
152 |
+
loss = outputs.loss
|
153 |
+
loss = (loss.sum() / (loss != 0).sum()).item()
|
154 |
+
losses.append(loss)
|
155 |
+
avg_ppl = np.array(losses).mean()
|
156 |
+
box_ppl.append(avg_ppl)
|
157 |
+
second_ppl = (np.array(box_ppl) * np.array(scores)).sum() / sum(scores)
|
158 |
+
return (first_ppl + second_ppl) / 2
|
159 |
+
|
160 |
+
|
161 |
+
def evaluate_cola(
|
162 |
+
model,
|
163 |
+
tokenizer,
|
164 |
+
image_processor,
|
165 |
+
vis_embed_size=None,
|
166 |
+
rank=0,
|
167 |
+
world_size=1,
|
168 |
+
id=0,
|
169 |
+
debug=False,
|
170 |
+
):
|
171 |
+
dataset_name = "cola"
|
172 |
+
dataset = json.load(open(DATASET))
|
173 |
+
model = model.cuda().eval()
|
174 |
+
correct = 0
|
175 |
+
total = 0
|
176 |
+
pbar = tqdm(dataset, disable=(rank != 0))
|
177 |
+
for ii, sample in enumerate(pbar):
|
178 |
+
if ii % world_size != rank:
|
179 |
+
continue
|
180 |
+
image1 = Image.open(os.path.join(VG_ROOT, os.path.basename(sample[0]))).convert("RGB").resize((224, 224))
|
181 |
+
text1 = sample[1]
|
182 |
+
image2 = Image.open(os.path.join(VG_ROOT, os.path.basename(sample[2]))).convert("RGB").resize((224, 224))
|
183 |
+
text2 = sample[3]
|
184 |
+
score11 = -get_score(image1, text1, model, tokenizer, image_processor, vis_embed_size)
|
185 |
+
score12 = -get_score(image1, text2, model, tokenizer, image_processor, vis_embed_size)
|
186 |
+
score21 = -get_score(image2, text1, model, tokenizer, image_processor, vis_embed_size)
|
187 |
+
score22 = -get_score(image2, text2, model, tokenizer, image_processor, vis_embed_size)
|
188 |
+
if rank == 0:
|
189 |
+
tqdm.write(f"{score11:.2f} {score12:.2f} {score21:.2f} {score22:.2f}")
|
190 |
+
if score11 > score21 and score22 > score12:
|
191 |
+
correct += 1
|
192 |
+
total += 1
|
193 |
+
pbar.set_description(f"{correct / total:.2f}")
|
194 |
+
print(rank, correct / total)
|
195 |
+
|
196 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
197 |
+
f.write(json.dumps([total, correct]))
|
198 |
+
if world_size > 1:
|
199 |
+
torch.distributed.barrier()
|
200 |
+
if rank == 0:
|
201 |
+
total = 0
|
202 |
+
correct = 0
|
203 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
204 |
+
for rank_i in range(world_size):
|
205 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
206 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
207 |
+
total += total_part
|
208 |
+
correct += correct_part
|
209 |
+
score = correct / total
|
210 |
+
print("score:", score)
|
211 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}_{total}"), "w") as f:
|
212 |
+
pass
|
213 |
+
else:
|
214 |
+
score = 0.0
|
215 |
+
if world_size > 1:
|
216 |
+
torch.distributed.barrier()
|
217 |
+
return score
|
218 |
+
|
219 |
+
if __name__ == "__main__":
|
220 |
+
evaluate_cola(None, None, None)
|
multimodal/build/lib/open_flamingo/eval/task/crepe.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import webdataset as wds
|
3 |
+
from tqdm import tqdm
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
import cv2
|
10 |
+
import random
|
11 |
+
import pandas as pd
|
12 |
+
from .vl_checklist import _eval_text_image
|
13 |
+
DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/instruct_data/crepe/prod_hard_negatives"
|
14 |
+
|
15 |
+
|
16 |
+
def evaluate_crepe(
|
17 |
+
model,
|
18 |
+
tokenizer,
|
19 |
+
image_processor,
|
20 |
+
vis_embed_size=None,
|
21 |
+
rank=0,
|
22 |
+
world_size=1,
|
23 |
+
id=0,
|
24 |
+
subset=True,
|
25 |
+
debug=False,
|
26 |
+
level=4,
|
27 |
+
type="swap",
|
28 |
+
):
|
29 |
+
if rank == 0:
|
30 |
+
tqdm.write(f"level: {level}")
|
31 |
+
tqdm.write(f"type: {type}")
|
32 |
+
dataset_name = "crepe"
|
33 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
34 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
35 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
36 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
37 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
38 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
39 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
40 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
41 |
+
model.eval().cuda()
|
42 |
+
total = 0
|
43 |
+
correct = 0
|
44 |
+
assert type in ["swap"]
|
45 |
+
assert 4 <= level <= 12
|
46 |
+
filename = os.path.join(DATASET_ROOT, type, f"prod_vg_hard_negs_{type}_complexity_{level}.csv")
|
47 |
+
df = pd.read_csv(filename)
|
48 |
+
pbar = tqdm(df.iterrows(), disable=(rank != 0))
|
49 |
+
for ii, sample in pbar:
|
50 |
+
if ii % world_size != rank:
|
51 |
+
continue
|
52 |
+
text = sample.caption
|
53 |
+
image_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/VG_100K/{}.jpg".format(sample.image_id)
|
54 |
+
x = sample.x
|
55 |
+
y = sample.y
|
56 |
+
width = sample.width
|
57 |
+
height = sample.height
|
58 |
+
image = Image.open(image_path).convert("RGB")
|
59 |
+
image = image.crop((x, y, x+width, y+height))
|
60 |
+
image = image.resize((224, 224))
|
61 |
+
final_rank, final_ranks = _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=debug)
|
62 |
+
if final_rank is None:
|
63 |
+
continue
|
64 |
+
correct += int((np.array(final_ranks) < 10).sum())
|
65 |
+
total += len(final_ranks)
|
66 |
+
if debug:
|
67 |
+
tqdm.write("="*80)
|
68 |
+
pbar.set_description(f"{text} | score: {correct / total:.4f} | {final_rank} | {final_ranks}")
|
69 |
+
|
70 |
+
|
71 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
72 |
+
f.write(json.dumps([total, correct]))
|
73 |
+
if world_size > 1:
|
74 |
+
torch.distributed.barrier()
|
75 |
+
if rank == 0:
|
76 |
+
total = 0
|
77 |
+
correct = 0
|
78 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
79 |
+
for rank_i in range(world_size):
|
80 |
+
[total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
81 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
82 |
+
total += total_part
|
83 |
+
correct += correct_part
|
84 |
+
score = correct / total
|
85 |
+
print("score:", score, "total:", total)
|
86 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
|
87 |
+
pass
|
88 |
+
else:
|
89 |
+
score = 0.0
|
90 |
+
if world_size > 1:
|
91 |
+
torch.distributed.barrier()
|
92 |
+
return score
|
93 |
+
|
multimodal/build/lib/open_flamingo/eval/task/gqa.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
import json
|
3 |
+
from PIL import Image
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import more_itertools
|
7 |
+
from tqdm import tqdm
|
8 |
+
import time
|
9 |
+
from vqa_metric import compute_gqa_accuracy
|
10 |
+
import string
|
11 |
+
import uuid
|
12 |
+
import numpy as np
|
13 |
+
import cv2
|
14 |
+
from open_flamingo.eval.task.utils import get_bbox
|
15 |
+
|
16 |
+
class GQADataset(Dataset):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
image_dir_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/gqa/images",
|
20 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/gqa/testdev_balanced_questions.json",
|
21 |
+
):
|
22 |
+
annotations = json.load(open(annotations_path))
|
23 |
+
self.questions = []
|
24 |
+
self.answers = []
|
25 |
+
self.image_paths = []
|
26 |
+
self.question_ids = []
|
27 |
+
for anno_id in annotations:
|
28 |
+
question = annotations[anno_id]["question"]
|
29 |
+
imageId = annotations[anno_id]["imageId"]
|
30 |
+
answer = annotations[anno_id]["answer"]
|
31 |
+
self.questions.append(question)
|
32 |
+
self.answers.append(answer)
|
33 |
+
self.image_paths.append(os.path.join(image_dir_path, "{}.jpg".format(imageId)))
|
34 |
+
self.question_ids.append(anno_id)
|
35 |
+
# print(annotations[anno_id]["types"])
|
36 |
+
self.vqa_dataset = "gqa"
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.questions)
|
40 |
+
|
41 |
+
def __getitem__(self, idx):
|
42 |
+
question = self.questions[idx]
|
43 |
+
question_id = self.question_ids[idx]
|
44 |
+
answer = self.answers[idx]
|
45 |
+
img_path = self.image_paths[idx]
|
46 |
+
image = Image.open(img_path)
|
47 |
+
return {
|
48 |
+
"image": image,
|
49 |
+
"question": question,
|
50 |
+
"answers": answer,
|
51 |
+
"question_id": question_id,
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
def prepare_batch_images(batch, image_processor):
|
56 |
+
batch_images = None
|
57 |
+
for b in batch:
|
58 |
+
b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
59 |
+
if batch_images is None:
|
60 |
+
batch_images = b_image
|
61 |
+
else:
|
62 |
+
batch_images = torch.cat([batch_images, b_image], dim=0)
|
63 |
+
return batch_images
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
def evaluate_gqa(
|
68 |
+
model,
|
69 |
+
tokenizer,
|
70 |
+
image_processor,
|
71 |
+
batch_size=1,
|
72 |
+
vis_embed_size=None,
|
73 |
+
rank=0,
|
74 |
+
world_size=1,
|
75 |
+
id=0,
|
76 |
+
):
|
77 |
+
"""
|
78 |
+
Evaluate a model on VQA datasets. Currently supports VQA v2.0.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
model (nn.Module): model to evaluate
|
82 |
+
tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
|
83 |
+
image_processor : image processor for the model
|
84 |
+
batch_size (int): batch size
|
85 |
+
image_dir_path (str): path to image directory
|
86 |
+
questions_json_path (str): path to questions json file
|
87 |
+
annotations_json_path (str): path to annotations json file
|
88 |
+
seed (int, optional): random seed. Defaults to 42.
|
89 |
+
max_generation_length (int, optional): max generation length. Defaults to 5.
|
90 |
+
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
|
91 |
+
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
|
92 |
+
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
|
93 |
+
query_set_size (int, optional): size of the query set. Defaults to 2048.
|
94 |
+
num_shots (int, optional): number of shots to use. Defaults to 8.
|
95 |
+
device (int, optional): device to use. Defaults to -1 (cpu).
|
96 |
+
num_workers (int, optional): number of workers to use. Defaults to 4.
|
97 |
+
vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
|
98 |
+
Returns:
|
99 |
+
float: accuracy score
|
100 |
+
"""
|
101 |
+
assert batch_size == 1
|
102 |
+
vqa_dataset = "gqa"
|
103 |
+
eval_dataset = GQADataset()
|
104 |
+
object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
|
105 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
106 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
107 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
108 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
109 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
110 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
111 |
+
def get_prompt(sample):
|
112 |
+
return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
|
113 |
+
model.eval().cuda()
|
114 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
115 |
+
predictions = []
|
116 |
+
if batch_size != 1:
|
117 |
+
tokenizer.padding_side = "left"
|
118 |
+
if world_size > 1:
|
119 |
+
torch.distributed.barrier()
|
120 |
+
this_tot = 0
|
121 |
+
for ii, batch in enumerate(more_itertools.chunked(
|
122 |
+
tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size,
|
123 |
+
)):
|
124 |
+
if ii % world_size != rank:
|
125 |
+
continue
|
126 |
+
batch[0]["image"] = batch[0]["image"].resize((224, 224))
|
127 |
+
batch_images = prepare_batch_images(
|
128 |
+
batch=batch,
|
129 |
+
image_processor=image_processor,
|
130 |
+
).cuda()
|
131 |
+
batch_text = [get_prompt(s) for s in batch]
|
132 |
+
encodings = tokenizer(
|
133 |
+
batch_text,
|
134 |
+
return_tensors="pt",
|
135 |
+
padding="longest",
|
136 |
+
truncation=True,
|
137 |
+
max_length=2000,
|
138 |
+
)
|
139 |
+
input_ids = encodings["input_ids"].cuda()
|
140 |
+
attention_mask = encodings["attention_mask"].cuda()
|
141 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
142 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
143 |
+
image_nums = [1] * len(input_ids)
|
144 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
145 |
+
outputs = model.generate(
|
146 |
+
batch_images,
|
147 |
+
input_ids,
|
148 |
+
attention_mask=attention_mask,
|
149 |
+
max_new_tokens=10,
|
150 |
+
min_length=1,
|
151 |
+
num_beams=1,
|
152 |
+
# length_penalty=0,
|
153 |
+
image_start_index_list=image_start_index_list,
|
154 |
+
image_nums=image_nums,
|
155 |
+
added_bbox_list=None,
|
156 |
+
return_dict_in_generate=True,
|
157 |
+
output_scores=True,
|
158 |
+
)
|
159 |
+
scores = outputs.scores
|
160 |
+
outputs = outputs.sequences[:, len(input_ids[0]) :]
|
161 |
+
if object_token_id in scores[0][0].sort(descending=True).indices[:5]:
|
162 |
+
sample = batch[0]
|
163 |
+
# print("="*80)
|
164 |
+
# print("sample:", batch, scores[0][0].sort(descending=True).indices[:10].tolist().index(object_token_id))
|
165 |
+
prompt1 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:<|#object#|><|#previsual#|>"]
|
166 |
+
boxes, scores = get_bbox(None, batch_images, prompt1, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
|
167 |
+
# open_cv_image = np.array(sample["image"])
|
168 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
169 |
+
# cv2.imwrite(f"Atest_ori.png", open_cv_image)
|
170 |
+
# open_cv_image = cv2.rectangle(open_cv_image, boxes[0][:2].astype(int), boxes[0][2:].astype(int), (0, 255, 0), 2)
|
171 |
+
# print(scores)
|
172 |
+
# cv2.imwrite(f"Atest.png", open_cv_image)
|
173 |
+
if boxes is not None and len(boxes) > 0:
|
174 |
+
prompt2 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer: it is<|#object#|><|#previsual#|><|#prebox#|><|#object#|> a"]
|
175 |
+
encodings = tokenizer(
|
176 |
+
prompt2,
|
177 |
+
return_tensors="pt",
|
178 |
+
padding="longest",
|
179 |
+
truncation=True,
|
180 |
+
max_length=2000,
|
181 |
+
)
|
182 |
+
input_ids = encodings["input_ids"].cuda()
|
183 |
+
attention_mask = encodings["attention_mask"].cuda()
|
184 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
185 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
186 |
+
image_nums = [1] * len(input_ids)
|
187 |
+
added_bbox_list = [torch.tensor(boxes[0]/224.0).cuda().unsqueeze(0).clamp(0, 0.99)]
|
188 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
189 |
+
outputs = model.generate(
|
190 |
+
batch_images,
|
191 |
+
input_ids,
|
192 |
+
attention_mask=attention_mask,
|
193 |
+
max_new_tokens=10,
|
194 |
+
min_length=1,
|
195 |
+
num_beams=1,
|
196 |
+
image_start_index_list=image_start_index_list,
|
197 |
+
image_nums=image_nums,
|
198 |
+
added_bbox_list=added_bbox_list,
|
199 |
+
eos_token_id=(endofobject_token_id),
|
200 |
+
)
|
201 |
+
outputs = outputs[:, len(input_ids[0]) :]
|
202 |
+
# print("previsual===>{}".format(tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower().strip(string.punctuation+" ")))
|
203 |
+
|
204 |
+
# postprocess begin
|
205 |
+
new_predictions = [
|
206 |
+
out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
207 |
+
]
|
208 |
+
this_tot += 1
|
209 |
+
predictions.extend(
|
210 |
+
[
|
211 |
+
{"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
|
212 |
+
for p, sample in zip(new_predictions, batch)
|
213 |
+
]
|
214 |
+
)
|
215 |
+
with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
|
216 |
+
f.write(json.dumps(predictions))
|
217 |
+
print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
|
218 |
+
|
219 |
+
time.sleep(10)
|
220 |
+
if world_size > 1:
|
221 |
+
torch.distributed.barrier()
|
222 |
+
if rank == 0:
|
223 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
224 |
+
predictions = []
|
225 |
+
for rank_i in range(world_size):
|
226 |
+
print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
227 |
+
predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
|
228 |
+
os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
|
229 |
+
print("num:", len(predictions))
|
230 |
+
# save the predictions to a temporary file
|
231 |
+
random_uuid = str(uuid.uuid4())
|
232 |
+
with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
|
233 |
+
f.write(json.dumps(predictions, indent=4))
|
234 |
+
|
235 |
+
acc = compute_gqa_accuracy(predictions)
|
236 |
+
print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
|
237 |
+
os.makedirs("eval_results", exist_ok=True)
|
238 |
+
with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
|
239 |
+
f.write(json.dumps(predictions, indent=2))
|
240 |
+
|
241 |
+
# delete the temporary file
|
242 |
+
os.remove(f"{vqa_dataset}results_{random_uuid}.json")
|
243 |
+
else:
|
244 |
+
time.sleep(5)
|
245 |
+
acc = 0.0
|
246 |
+
if world_size > 1:
|
247 |
+
torch.distributed.barrier()
|
248 |
+
return acc
|
multimodal/build/lib/open_flamingo/eval/task/mmbench.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
+
import random
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
from PIL import Image
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from open_flamingo.eval.task.utils import get_object_from_text
|
9 |
+
|
10 |
+
def decode_base64_to_image(base64_string):
|
11 |
+
image_data = base64.b64decode(base64_string)
|
12 |
+
image = Image.open(io.BytesIO(image_data))
|
13 |
+
return image
|
14 |
+
|
15 |
+
class MMBenchDataset(Dataset):
|
16 |
+
def __init__(self,
|
17 |
+
data_file,
|
18 |
+
sys_prompt='There are several options:'):
|
19 |
+
self.df = pd.read_csv(data_file, sep='\t')
|
20 |
+
self.sys_prompt = sys_prompt
|
21 |
+
|
22 |
+
def __len__(self):
|
23 |
+
return len(self.df)
|
24 |
+
|
25 |
+
def __getitem__(self, idx):
|
26 |
+
index = self.df.iloc[idx]['index']
|
27 |
+
image = self.df.iloc[idx]['image']
|
28 |
+
image = decode_base64_to_image(image)
|
29 |
+
question = self.df.iloc[idx]['question']
|
30 |
+
answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[0].keys() else None
|
31 |
+
catetory = self.df.iloc[idx]['category']
|
32 |
+
l2_catetory = self.df.iloc[idx]['l2-category']
|
33 |
+
|
34 |
+
option_candidate = ['A', 'B', 'C', 'D', 'E']
|
35 |
+
options = {
|
36 |
+
cand: self.load_from_df(idx, cand)
|
37 |
+
for cand in option_candidate
|
38 |
+
if self.load_from_df(idx, cand) is not None
|
39 |
+
}
|
40 |
+
options_prompt = f'{self.sys_prompt}\n'
|
41 |
+
for key, item in options.items():
|
42 |
+
options_prompt += f'{key}. {item}\n'
|
43 |
+
|
44 |
+
hint = self.load_from_df(idx, 'hint')
|
45 |
+
data = {
|
46 |
+
'img': image,
|
47 |
+
'question': question,
|
48 |
+
'answer': answer,
|
49 |
+
'options': options_prompt,
|
50 |
+
'category': catetory,
|
51 |
+
'l2-category': l2_catetory,
|
52 |
+
'options_dict': options,
|
53 |
+
'index': index,
|
54 |
+
'context': hint,
|
55 |
+
}
|
56 |
+
return data
|
57 |
+
def load_from_df(self, idx, key):
|
58 |
+
if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
|
59 |
+
return self.df.iloc[idx][key]
|
60 |
+
else:
|
61 |
+
return None
|
62 |
+
|
63 |
+
|
64 |
+
def evaluate_mmbench(
|
65 |
+
model,
|
66 |
+
tokenizer,
|
67 |
+
image_processor,
|
68 |
+
batch_size=1,
|
69 |
+
image_dir_path=None,
|
70 |
+
questions_json_path=None,
|
71 |
+
annotations_json_path=None,
|
72 |
+
vis_embed_size=None,
|
73 |
+
rank=0,
|
74 |
+
world_size=1,
|
75 |
+
id=0,
|
76 |
+
):
|
77 |
+
dataset_name = "mmbench"
|
78 |
+
dataset = MMBenchDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/mmbench/mmbench_dev_20230712.tsv")
|
79 |
+
for sample in dataset:
|
80 |
+
print(sample)
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == '__main__':
|
84 |
+
evaluate_mmbench(None, None, None)
|
multimodal/build/lib/open_flamingo/eval/task/reg.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
from PIL import Image
|
4 |
+
from io import BytesIO
|
5 |
+
import base64
|
6 |
+
import numpy as np
|
7 |
+
import time
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
import cv2
|
11 |
+
from coco_metric import compute_cider
|
12 |
+
import random
|
13 |
+
import pickle
|
14 |
+
|
15 |
+
def evaluate_reg(
|
16 |
+
model,
|
17 |
+
tokenizer,
|
18 |
+
image_processor,
|
19 |
+
vis_embed_size=None,
|
20 |
+
rank=0,
|
21 |
+
world_size=1,
|
22 |
+
id=0,
|
23 |
+
):
|
24 |
+
lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
|
25 |
+
dataset_name = "refcocog"
|
26 |
+
pkl_file = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_data.pkl"
|
27 |
+
try:
|
28 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
29 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
30 |
+
pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
|
31 |
+
bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
|
32 |
+
except:
|
33 |
+
pass
|
34 |
+
|
35 |
+
model.eval().cuda()
|
36 |
+
if world_size > 1:
|
37 |
+
torch.distributed.barrier()
|
38 |
+
this_tot = 0
|
39 |
+
predictions = []
|
40 |
+
D = pickle.load(open(pkl_file, "rb"))
|
41 |
+
lines = []
|
42 |
+
data = D["data"]
|
43 |
+
uniq_id_to_text = D["uniq_id_to_text"]
|
44 |
+
uniq_id_to_image = D["uniq_id_to_image"]
|
45 |
+
uniq_id_to_image_id = D["uniq_id_to_image_id"]
|
46 |
+
for image_id in data:
|
47 |
+
for region in data[image_id]:
|
48 |
+
uniq_id = data[image_id][region][0]
|
49 |
+
lines.append([uniq_id, uniq_id_to_image_id[uniq_id], [uniq_id_to_text[r] for r in data[image_id][region]], region, uniq_id_to_image[uniq_id]])
|
50 |
+
print("total data:", len(lines))
|
51 |
+
# lines = lines[:20]
|
52 |
+
pbar = tqdm(lines, disable=(rank != 0))
|
53 |
+
for ii, line in enumerate(pbar):
|
54 |
+
if ii % world_size != rank:
|
55 |
+
continue
|
56 |
+
uniq_id, image_id, text, region_coord, image = line
|
57 |
+
gt_box = np.array(region_coord)
|
58 |
+
width = image.width
|
59 |
+
height = image.height
|
60 |
+
image = image.resize((224, 224))
|
61 |
+
gt_box = gt_box / np.array([width, height, width, height]) * 224
|
62 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
63 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#object#|>"]
|
64 |
+
|
65 |
+
encodings = tokenizer(
|
66 |
+
prompt,
|
67 |
+
padding="longest",
|
68 |
+
truncation=True,
|
69 |
+
return_tensors="pt",
|
70 |
+
max_length=2000,
|
71 |
+
)
|
72 |
+
input_ids = encodings["input_ids"]
|
73 |
+
attention_mask = encodings["attention_mask"]
|
74 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
75 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
76 |
+
image_nums = [1] * len(input_ids)
|
77 |
+
batch_images = batch_images.cuda()
|
78 |
+
input_ids = input_ids.cuda()
|
79 |
+
attention_mask = attention_mask.cuda()
|
80 |
+
added_bbox_list = [(torch.tensor(gt_box).cuda() / 224).clamp(0, 0.99).unsqueeze(0)]
|
81 |
+
|
82 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
83 |
+
outputs = model.generate(
|
84 |
+
batch_images,
|
85 |
+
input_ids,
|
86 |
+
attention_mask=attention_mask,
|
87 |
+
max_new_tokens=25,
|
88 |
+
min_length=5,
|
89 |
+
num_beams=8,
|
90 |
+
length_penalty=0,
|
91 |
+
image_start_index_list=image_start_index_list,
|
92 |
+
image_nums=image_nums,
|
93 |
+
added_bbox_list=added_bbox_list,
|
94 |
+
)
|
95 |
+
outputs = outputs[:, len(input_ids[0]) :]
|
96 |
+
new_prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip().lower()
|
97 |
+
this_tot += 1
|
98 |
+
if rank == 0 and this_tot % 10 == 0:
|
99 |
+
for i in range(1):
|
100 |
+
tqdm.write(f"answer: {text}\nmodel output: {new_prediction}")
|
101 |
+
predictions.append(
|
102 |
+
{"image_id": image_id, "caption": new_prediction}
|
103 |
+
)
|
104 |
+
results_path = f"reg_{lang_encoder_name}_{rank}_{id}.json"
|
105 |
+
json.dump(predictions, open(results_path, "w"))
|
106 |
+
print("save to", results_path)
|
107 |
+
del predictions
|
108 |
+
time.sleep(5)
|
109 |
+
if world_size > 1:
|
110 |
+
torch.distributed.barrier()
|
111 |
+
if rank == 0:
|
112 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
113 |
+
predictions = []
|
114 |
+
for rank_i in range(world_size):
|
115 |
+
part_results_path = f"reg_{lang_encoder_name}_{rank_i}_{id}.json"
|
116 |
+
print("load", part_results_path)
|
117 |
+
part_data = json.load(open(part_results_path))
|
118 |
+
predictions.extend(part_data)
|
119 |
+
os.remove(part_results_path)
|
120 |
+
print("num:", len(predictions))
|
121 |
+
results_path = f"reg_{lang_encoder_name}_{id}_result.json"
|
122 |
+
json.dump(predictions, open(results_path, "w"), indent=2)
|
123 |
+
|
124 |
+
metrics = compute_cider(
|
125 |
+
result_path=results_path,
|
126 |
+
annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_label.json",
|
127 |
+
)
|
128 |
+
os.makedirs("eval_results", exist_ok=True)
|
129 |
+
cider = metrics["CIDEr"]
|
130 |
+
print("cider", cider)
|
131 |
+
with open(os.path.join("eval_results", f"reg_{model.expr_name}_{model.step_num}_{int(time.time())}_{cider}"), "w") as f:
|
132 |
+
f.write(json.dumps(predictions, indent=2))
|
133 |
+
# delete the temporary file
|
134 |
+
os.remove(results_path)
|
135 |
+
return cider
|
136 |
+
|
137 |
+
|
138 |
+
if __name__ == "__main__":
|
139 |
+
anno = json.load(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json"))
|
140 |
+
import pdb; pdb.set_trace()
|
141 |
+
print(anno.keys())
|
multimodal/build/lib/open_flamingo/eval/task/utils.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spacy
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import numpy as np
|
5 |
+
import itertools
|
6 |
+
nlp = spacy.load('en_core_web_md')
|
7 |
+
|
8 |
+
|
9 |
+
def get_iou(box1, box2):
|
10 |
+
# box1 and box2 should be in the format [x1, y1, x2, y2]
|
11 |
+
intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
|
12 |
+
max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
|
13 |
+
area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
14 |
+
area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
15 |
+
union = area_box1 + area_box2 - intersection
|
16 |
+
iou = intersection / union if union > 0 else 0
|
17 |
+
return iou
|
18 |
+
|
19 |
+
|
20 |
+
# def find_root(token):
|
21 |
+
# if token.pos_ == "VERB":
|
22 |
+
# return token
|
23 |
+
# while token.dep_ not in ["pobj", "nsubj", "ROOT", "npadvmod", "dobj", "det", "prep", "punct", "cc", "conj", "acl", "dep", "appos", "relcl", "advmod", "nmod", "attr"]:
|
24 |
+
# token = token.head
|
25 |
+
# return token
|
26 |
+
|
27 |
+
|
28 |
+
def find_root(token):
|
29 |
+
if token.pos_ == "VERB":
|
30 |
+
return token
|
31 |
+
while token.dep_ in ["compound", "amod"]:
|
32 |
+
token = token.head
|
33 |
+
return token
|
34 |
+
|
35 |
+
def get_object_from_text(text, verbose=False):
|
36 |
+
if len(text.split(" ")) == 3:
|
37 |
+
text = text.split(" ")
|
38 |
+
return [text[0], text[-1]]
|
39 |
+
doc = nlp(text)
|
40 |
+
if verbose:
|
41 |
+
for TT in doc:
|
42 |
+
print(TT.text, TT.pos_, TT.dep_, TT.head)
|
43 |
+
roots = set()
|
44 |
+
for i, token in enumerate(doc):
|
45 |
+
roots.add(find_root(token))
|
46 |
+
exprs = []
|
47 |
+
roots = sorted(list(roots), key=lambda token: token.idx)
|
48 |
+
first_nsubj = True
|
49 |
+
if verbose:
|
50 |
+
print(roots)
|
51 |
+
for root in roots:
|
52 |
+
if root.pos_ not in ["NOUN", "PROPN"]:
|
53 |
+
continue
|
54 |
+
if root.dep_ not in ["pobj", "nsubj"]:
|
55 |
+
continue
|
56 |
+
if not first_nsubj and root.dep_ in ["nsubj"]:
|
57 |
+
continue
|
58 |
+
exprs.append([])
|
59 |
+
for token in doc:
|
60 |
+
if find_root(token) == root:
|
61 |
+
exprs[-1].append(token.text)
|
62 |
+
exprs[-1] = " ".join(exprs[-1]).replace(" '", "'")
|
63 |
+
if exprs[-1] not in text:
|
64 |
+
if verbose:
|
65 |
+
print("not in text error:", exprs[-1], "#",text)
|
66 |
+
# for TT in doc:
|
67 |
+
# print(TT.text, TT.pos_, TT.dep_, TT.head)
|
68 |
+
# import pdb; pdb.set_trace()
|
69 |
+
exprs.pop()
|
70 |
+
if first_nsubj and root.dep_ in ["nsubj"]:
|
71 |
+
first_nsubj = False
|
72 |
+
if len(exprs) <= 1:
|
73 |
+
if verbose:
|
74 |
+
print("not enough exprs error:", exprs, "#",text)
|
75 |
+
return []
|
76 |
+
return exprs
|
77 |
+
|
78 |
+
def is_correct(input_ids, logits, tokenizer, object: str, topk=5, N=10):
|
79 |
+
answer_id = torch.tensor(tokenizer(f" {object}", add_special_tokens=False)["input_ids"]).to(input_ids.device)
|
80 |
+
answer_begin_idx = (input_ids == answer_id[0]).nonzero()
|
81 |
+
answer_idx = None
|
82 |
+
for (batch_idx, IDX) in answer_begin_idx:
|
83 |
+
try:
|
84 |
+
if (input_ids[batch_idx, IDX:IDX+len(answer_id)] == answer_id).all():
|
85 |
+
answer_idx = list(range(IDX-1, IDX+len(answer_id)-1))
|
86 |
+
except:
|
87 |
+
pass
|
88 |
+
if answer_idx is None:
|
89 |
+
return np.inf, False, False
|
90 |
+
res = logits[0, answer_idx].softmax(-1).sort(descending=True)
|
91 |
+
values = res.values
|
92 |
+
indices = res.indices
|
93 |
+
chosen_ids = list(itertools.product(*([list(range(N))]*len(answer_idx))))
|
94 |
+
probs = []
|
95 |
+
for ids in chosen_ids:
|
96 |
+
prob = 1.0
|
97 |
+
for i, id in enumerate(ids):
|
98 |
+
prob *= values[i, id]
|
99 |
+
probs.append((prob.item(), ids))
|
100 |
+
probs.sort(reverse=True)
|
101 |
+
answer_pos = tuple([id_array.tolist().index(idx) for id_array, idx in zip(indices, answer_id)])
|
102 |
+
ranking = [p[1] for p in probs]
|
103 |
+
# if len(answer_idx) > 1:
|
104 |
+
# import pdb; pdb.set_trace()
|
105 |
+
try:
|
106 |
+
r = ranking.index(answer_pos)
|
107 |
+
return r, r < 1, r < 5
|
108 |
+
except:
|
109 |
+
return np.inf, False, False
|
110 |
+
|
111 |
+
def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
|
112 |
+
assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
|
113 |
+
encodings = tokenizer(
|
114 |
+
prompt,
|
115 |
+
padding="longest",
|
116 |
+
truncation=True,
|
117 |
+
return_tensors="pt",
|
118 |
+
max_length=2000,
|
119 |
+
)
|
120 |
+
input_ids = encodings["input_ids"]
|
121 |
+
attention_mask = encodings["attention_mask"]
|
122 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
123 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
124 |
+
image_nums = [1] * len(input_ids)
|
125 |
+
vision_x = batch_images.cuda()
|
126 |
+
lang_x = input_ids.cuda()
|
127 |
+
attention_mask = attention_mask.cuda()
|
128 |
+
|
129 |
+
model.debug_id = 0
|
130 |
+
with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
|
131 |
+
outputs = model(
|
132 |
+
vision_x=vision_x,
|
133 |
+
lang_x=lang_x,
|
134 |
+
attention_mask=attention_mask,
|
135 |
+
labels=None,
|
136 |
+
image_nums=image_nums,
|
137 |
+
image_start_index_list=image_start_index_list,
|
138 |
+
added_bbox_list=visual_box_list,
|
139 |
+
add_box=visual_box_list is not None,
|
140 |
+
relations=None,
|
141 |
+
debug_mode=False,
|
142 |
+
)
|
143 |
+
boxes = outputs["boxes"]
|
144 |
+
scores = outputs["scores"]
|
145 |
+
if debug:
|
146 |
+
import pdb; pdb.set_trace()
|
147 |
+
if return_all:
|
148 |
+
return boxes, scores
|
149 |
+
if len(scores) == 0:
|
150 |
+
return None, None
|
151 |
+
else:
|
152 |
+
return boxes[scores.argmax()], scores.max()
|
153 |
+
|
154 |
+
|
155 |
+
def _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=False, objects=None):
|
156 |
+
batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
157 |
+
if objects is None:
|
158 |
+
objects = get_object_from_text(text)
|
159 |
+
if len(objects) == 0:
|
160 |
+
return None, None, None
|
161 |
+
if debug:
|
162 |
+
tqdm.write(text)
|
163 |
+
tqdm.write(f"{objects}")
|
164 |
+
first_idx = text.find(objects[0])
|
165 |
+
if first_idx == 0:
|
166 |
+
first_text = f"<|#object#|>{objects[0]}<|#endofobject#|><|#visual#|>"
|
167 |
+
else:
|
168 |
+
first_text = text[:first_idx-1] + f"<|#object#|> {objects[0]}<|#endofobject#|><|#visual#|>"
|
169 |
+
|
170 |
+
if debug:
|
171 |
+
tqdm.write(first_text)
|
172 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
|
173 |
+
# import pdb; pdb.set_trace()
|
174 |
+
# print("do first get_bbox |", first_text)
|
175 |
+
first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
|
176 |
+
if not model.valid and debug:
|
177 |
+
import pdb; pdb.set_trace()
|
178 |
+
if first_box is not None:
|
179 |
+
added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
|
180 |
+
text = first_text + "<|#box#|><|#endofobject#|>" + text[first_idx+len(objects[0]):]
|
181 |
+
else:
|
182 |
+
added_bbox_list = []
|
183 |
+
|
184 |
+
final_ranks = []
|
185 |
+
is_top1_list = []
|
186 |
+
is_top5_list = []
|
187 |
+
for kk, object in enumerate(objects):
|
188 |
+
if kk == 0:
|
189 |
+
continue
|
190 |
+
idx = text.find(objects[0])
|
191 |
+
for t_i, temp in enumerate(objects[1:kk+1]):
|
192 |
+
# t_i is actually the previous one. This is not a bug
|
193 |
+
idx = text.find(temp, idx + len(objects[t_i]))
|
194 |
+
while idx+len(temp) != len(text) and (text[idx-1] == "#" or text[idx+len(temp)] == "#"):
|
195 |
+
# in case temp is box or object or visual or something like that
|
196 |
+
idx = text.find(temp, idx + len(temp))
|
197 |
+
this_text = text[:idx-1] + "<|#object#|><|#previsual#|>"
|
198 |
+
# if this_text == "<|#object#|><|#previsual#|>":
|
199 |
+
# import pdb; pdb.set_trace()
|
200 |
+
if debug:
|
201 |
+
tqdm.write(this_text)
|
202 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{this_text}"]
|
203 |
+
# import pdb; pdb.set_trace()
|
204 |
+
# print("do pre get_bbox |", this_text)
|
205 |
+
pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
|
206 |
+
prebox_token_id, return_all=True)
|
207 |
+
if not model.valid and debug:
|
208 |
+
import pdb; pdb.set_trace()
|
209 |
+
logits_list = []
|
210 |
+
# pre_boxes = [pre_boxes[0]]
|
211 |
+
# pre_scores = [pre_scores[0]]
|
212 |
+
this_text = this_text + f"<|#prebox#|><|#object#|> {object}<|#endofobject#|>"
|
213 |
+
for pre_box, pre_score in zip(pre_boxes, pre_scores):
|
214 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{this_text}"]
|
215 |
+
encodings = tokenizer(
|
216 |
+
prompt,
|
217 |
+
padding="longest",
|
218 |
+
truncation=True,
|
219 |
+
return_tensors="pt",
|
220 |
+
max_length=512,
|
221 |
+
)
|
222 |
+
input_ids = encodings["input_ids"]
|
223 |
+
attention_mask = encodings["attention_mask"]
|
224 |
+
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
225 |
+
image_start_index_list = [[x] for x in image_start_index_list]
|
226 |
+
image_nums = [1] * len(input_ids)
|
227 |
+
vision_x = batch_images.cuda()
|
228 |
+
lang_x = input_ids.cuda()
|
229 |
+
attention_mask = attention_mask.cuda()
|
230 |
+
this_added_bbox_list = added_bbox_list + [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
|
231 |
+
|
232 |
+
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
|
233 |
+
outputs = model(
|
234 |
+
vision_x=vision_x,
|
235 |
+
lang_x=lang_x,
|
236 |
+
attention_mask=attention_mask,
|
237 |
+
image_nums=image_nums,
|
238 |
+
image_start_index_list=image_start_index_list,
|
239 |
+
added_bbox_list=this_added_bbox_list,
|
240 |
+
add_box=this_added_bbox_list is not None and len(this_added_bbox_list) != 0,
|
241 |
+
relations=None,
|
242 |
+
)
|
243 |
+
if not model.valid and debug:
|
244 |
+
import pdb; pdb.set_trace()
|
245 |
+
logits_list.append([pre_score, outputs.logits])
|
246 |
+
if debug:
|
247 |
+
answer_start_idx = (lang_x == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1]
|
248 |
+
logits = outputs["logits"][0, answer_start_idx:]
|
249 |
+
tqdm.write(tokenizer.decode(logits[0].sort(descending=True).indices.tolist()[:10]))
|
250 |
+
# if debug:
|
251 |
+
# image.save("Atest.png")
|
252 |
+
# open_cv_image = np.array(image)
|
253 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
254 |
+
# if first_box is not None:
|
255 |
+
# open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
|
256 |
+
# if pre_box is not None:
|
257 |
+
# open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
|
258 |
+
# cv2.imwrite(f"Atest.png", open_cv_image)
|
259 |
+
# import pdb; pdb.set_trace()
|
260 |
+
pre_scores = np.array([x[0] for x in logits_list])
|
261 |
+
final_probs = 0.0
|
262 |
+
for score, (_, logits) in zip(pre_scores, logits_list):
|
263 |
+
final_probs += score * logits.softmax(-1)
|
264 |
+
assert input_ids.shape[:2] == final_probs.shape[:2]
|
265 |
+
_rank, is_top1, is_top5 = is_correct(input_ids, final_probs, tokenizer, object, topk=5)
|
266 |
+
final_ranks.append(_rank)
|
267 |
+
is_top1_list.append(is_top1)
|
268 |
+
is_top5_list.append(is_top5)
|
269 |
+
this_text = text[:idx-1] + f"<|#object#|> {object}<|#endofobject#|><|#visual#|>"
|
270 |
+
if debug:
|
271 |
+
tqdm.write(this_text)
|
272 |
+
prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{this_text}"]
|
273 |
+
# print("do this get_bbox |", this_text)
|
274 |
+
this_box, this_score = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
|
275 |
+
if not model.valid and debug:
|
276 |
+
import pdb; pdb.set_trace()
|
277 |
+
if this_box is not None:
|
278 |
+
added_bbox_list += [torch.tensor(this_box).unsqueeze(0).cuda() / 224]
|
279 |
+
text = this_text + "<|#box#|><|#endofobject#|>" + text[idx+len(object):]
|
280 |
+
return final_ranks, is_top1_list, is_top5_list
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
|
285 |
+
if __name__ == "__main__":
|
286 |
+
# print(get_object_from_text("there is a cookie. there is a bear. white orio cookie is next to the teddy bear. car runs on the traffic road. there is a tree.", verbose=False))
|
287 |
+
print(get_object_from_text("President speaks to an American at a business office",verbose=True))
|
multimodal/build/lib/open_flamingo/eval/task/vl_checklist.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import webdataset as wds
|
3 |
+
from tqdm import tqdm
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
import cv2
|
10 |
+
import random
|
11 |
+
from open_flamingo.eval.task.utils import (
|
12 |
+
get_object_from_text,
|
13 |
+
is_correct,
|
14 |
+
_eval_text_image,
|
15 |
+
)
|
16 |
+
DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/instruct_data/instruct/vl_checklist/Relation/000000.tar"
|
17 |
+
|
18 |
+
def evaluate_vlc(
|
19 |
+
model,
|
20 |
+
tokenizer,
|
21 |
+
image_processor,
|
22 |
+
vis_embed_size=None,
|
23 |
+
rank=0,
|
24 |
+
world_size=1,
|
25 |
+
id=0,
|
26 |
+
subset=True,
|
27 |
+
subset_size="5k",
|
28 |
+
debug=False,
|
29 |
+
):
|
30 |
+
dataset_name = "vlc"
|
31 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
32 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
33 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
34 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
35 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
36 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
37 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
38 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
39 |
+
model.eval().cuda()
|
40 |
+
total = 0
|
41 |
+
n_top1 = 0
|
42 |
+
n_top5 = 0
|
43 |
+
n_top10 = 0
|
44 |
+
filename = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/vlc_data.json" if not subset else f"/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/vlc_data_subset_{subset_size}.json"
|
45 |
+
dataset = json.load(open(filename))
|
46 |
+
|
47 |
+
pbar = tqdm(dataset, disable=(rank != 0))
|
48 |
+
for ii, sample in enumerate(pbar):
|
49 |
+
if ii % world_size != rank:
|
50 |
+
continue
|
51 |
+
text, image_path = sample
|
52 |
+
image = Image.open(image_path).convert("RGB")
|
53 |
+
image = image.resize((224, 224))
|
54 |
+
final_ranks, is_top1_list, is_top5_list = _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=debug)
|
55 |
+
if final_ranks is None:
|
56 |
+
continue
|
57 |
+
n_top1 += int(sum(is_top1_list))
|
58 |
+
n_top5 += int(sum(is_top5_list))
|
59 |
+
n_top10 += int((np.array(final_ranks) < 10).sum())
|
60 |
+
total += len(final_ranks)
|
61 |
+
if debug:
|
62 |
+
tqdm.write("="*80)
|
63 |
+
pbar.set_description(f"acc@top1: {n_top1 / total:.4f} | acc@top5: {n_top5 / total:.4f} | acc@top10: {n_top10 / total:.4f} | {final_ranks} |{text}")
|
64 |
+
|
65 |
+
|
66 |
+
with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
|
67 |
+
f.write(json.dumps([total, n_top1, n_top5, n_top10]))
|
68 |
+
if world_size > 1:
|
69 |
+
torch.distributed.barrier()
|
70 |
+
if rank == 0:
|
71 |
+
total = 0
|
72 |
+
n_top1 = 0
|
73 |
+
n_top5 = 0
|
74 |
+
n_top10 = 0
|
75 |
+
print(f"evaluate on rank {rank}. world size is {world_size}")
|
76 |
+
for rank_i in range(world_size):
|
77 |
+
[total_part, n_top1_part, n_top5_part, n_top10_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
|
78 |
+
os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
|
79 |
+
total += total_part
|
80 |
+
n_top1 += n_top1_part
|
81 |
+
n_top5 += n_top5_part
|
82 |
+
n_top10 += n_top10_part
|
83 |
+
print("acc@top1:", n_top1 / total, "acc@top5:", n_top5 / total, "acc@top10:", n_top10 / total, "total:", total)
|
84 |
+
with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{n_top1 / total}_{n_top5 / total}_{n_top10 / total}_{total}"), "w") as f:
|
85 |
+
pass
|
86 |
+
else:
|
87 |
+
score = 0.0
|
88 |
+
if world_size > 1:
|
89 |
+
torch.distributed.barrier()
|
90 |
+
return score
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
dataset = wds.WebDataset(DATASET_ROOT).decode().shuffle(100000).to_tuple("data.pyd", "dataset.txt", "image_path.txt")
|
95 |
+
labels = set()
|
96 |
+
texts = []
|
97 |
+
data_pair = []
|
98 |
+
if not os.path.exists("vlc_data.json"):
|
99 |
+
for sample in tqdm(dataset):
|
100 |
+
data, dataset_name, image_path = sample
|
101 |
+
text = data[-1]["POS"][0]
|
102 |
+
texts.append(text)
|
103 |
+
data_pair.append([text, image_path])
|
104 |
+
json.dump(data_pair, open("vlc_data.json", "w"), indent=1)
|
105 |
+
else:
|
106 |
+
print("data exists")
|
107 |
+
data_pair = json.load(open("vlc_data.json"))
|
108 |
+
for text, image_path in data_pair:
|
109 |
+
texts.append(text)
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
print(get_object_from_text("crow attacks the dove"))
|
multimodal/build/lib/open_flamingo/eval/vqa_metric.py
ADDED
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
|
9 |
+
# Interface for accessing the VQA dataset.
|
10 |
+
|
11 |
+
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
12 |
+
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
|
13 |
+
|
14 |
+
# The following functions are defined:
|
15 |
+
# VQA - VQA class that loads VQA annotation file and prepares data structures.
|
16 |
+
# getQuesIds - Get question ids that satisfy given filter conditions.
|
17 |
+
# getImgIds - Get image ids that satisfy given filter conditions.
|
18 |
+
# loadQA - Load questions and answers with the specified question ids.
|
19 |
+
# showQA - Display the specified questions and answers.
|
20 |
+
# loadRes - Load result file and create result object.
|
21 |
+
|
22 |
+
# Help on each function can be accessed by: "help(COCO.function)"
|
23 |
+
|
24 |
+
|
25 |
+
class VQA:
|
26 |
+
def __init__(self, annotation_file=None, question_file=None):
|
27 |
+
"""
|
28 |
+
Constructor of VQA helper class for reading and visualizing questions and answers.
|
29 |
+
:param annotation_file (str): location of VQA annotation file
|
30 |
+
:return:
|
31 |
+
"""
|
32 |
+
# load dataset
|
33 |
+
self.dataset = {}
|
34 |
+
self.questions = {}
|
35 |
+
self.qa = {}
|
36 |
+
self.qqa = {}
|
37 |
+
self.imgToQA = {}
|
38 |
+
if not annotation_file == None and not question_file == None:
|
39 |
+
print("loading VQA annotations and questions into memory...")
|
40 |
+
time_t = datetime.datetime.utcnow()
|
41 |
+
dataset = json.load(open(annotation_file, "r"))
|
42 |
+
questions = json.load(open(question_file, "r"))
|
43 |
+
print(datetime.datetime.utcnow() - time_t)
|
44 |
+
self.dataset = dataset
|
45 |
+
self.questions = questions
|
46 |
+
self.createIndex()
|
47 |
+
|
48 |
+
def createIndex(self):
|
49 |
+
# create index
|
50 |
+
print("creating index...")
|
51 |
+
imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
|
52 |
+
qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
|
53 |
+
qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
|
54 |
+
for ann in self.dataset["annotations"]:
|
55 |
+
imgToQA[ann["image_id"]] += [ann]
|
56 |
+
qa[ann["question_id"]] = ann
|
57 |
+
for ques in self.questions["questions"]:
|
58 |
+
qqa[ques["question_id"]] = ques
|
59 |
+
print("index created!")
|
60 |
+
|
61 |
+
# create class members
|
62 |
+
self.qa = qa
|
63 |
+
self.qqa = qqa
|
64 |
+
self.imgToQA = imgToQA
|
65 |
+
|
66 |
+
def info(self):
|
67 |
+
"""
|
68 |
+
Print information about the VQA annotation file.
|
69 |
+
:return:
|
70 |
+
"""
|
71 |
+
for key, value in self.dataset["info"].items():
|
72 |
+
print("%s: %s" % (key, value))
|
73 |
+
|
74 |
+
def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
|
75 |
+
"""
|
76 |
+
Get question ids that satisfy given filter conditions. default skips that filter
|
77 |
+
:param imgIds (int array) : get question ids for given imgs
|
78 |
+
quesTypes (str array) : get question ids for given question types
|
79 |
+
ansTypes (str array) : get question ids for given answer types
|
80 |
+
:return: ids (int array) : integer array of question ids
|
81 |
+
"""
|
82 |
+
imgIds = imgIds if type(imgIds) == list else [imgIds]
|
83 |
+
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
84 |
+
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
85 |
+
|
86 |
+
if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
|
87 |
+
anns = self.dataset["annotations"]
|
88 |
+
else:
|
89 |
+
if not len(imgIds) == 0:
|
90 |
+
anns = sum(
|
91 |
+
[self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
|
92 |
+
[],
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
anns = self.dataset["annotations"]
|
96 |
+
anns = (
|
97 |
+
anns
|
98 |
+
if len(quesTypes) == 0
|
99 |
+
else [ann for ann in anns if ann["question_type"] in quesTypes]
|
100 |
+
)
|
101 |
+
anns = (
|
102 |
+
anns
|
103 |
+
if len(ansTypes) == 0
|
104 |
+
else [ann for ann in anns if ann["answer_type"] in ansTypes]
|
105 |
+
)
|
106 |
+
ids = [ann["question_id"] for ann in anns]
|
107 |
+
return ids
|
108 |
+
|
109 |
+
def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
|
110 |
+
"""
|
111 |
+
Get image ids that satisfy given filter conditions. default skips that filter
|
112 |
+
:param quesIds (int array) : get image ids for given question ids
|
113 |
+
quesTypes (str array) : get image ids for given question types
|
114 |
+
ansTypes (str array) : get image ids for given answer types
|
115 |
+
:return: ids (int array) : integer array of image ids
|
116 |
+
"""
|
117 |
+
quesIds = quesIds if type(quesIds) == list else [quesIds]
|
118 |
+
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
|
119 |
+
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
|
120 |
+
|
121 |
+
if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
|
122 |
+
anns = self.dataset["annotations"]
|
123 |
+
else:
|
124 |
+
if not len(quesIds) == 0:
|
125 |
+
anns = sum(
|
126 |
+
[self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
anns = self.dataset["annotations"]
|
130 |
+
anns = (
|
131 |
+
anns
|
132 |
+
if len(quesTypes) == 0
|
133 |
+
else [ann for ann in anns if ann["question_type"] in quesTypes]
|
134 |
+
)
|
135 |
+
anns = (
|
136 |
+
anns
|
137 |
+
if len(ansTypes) == 0
|
138 |
+
else [ann for ann in anns if ann["answer_type"] in ansTypes]
|
139 |
+
)
|
140 |
+
ids = [ann["image_id"] for ann in anns]
|
141 |
+
return ids
|
142 |
+
|
143 |
+
def loadQA(self, ids=[]):
|
144 |
+
"""
|
145 |
+
Load questions and answers with the specified question ids.
|
146 |
+
:param ids (int array) : integer ids specifying question ids
|
147 |
+
:return: qa (object array) : loaded qa objects
|
148 |
+
"""
|
149 |
+
if type(ids) == list:
|
150 |
+
return [self.qa[id] for id in ids]
|
151 |
+
elif type(ids) == int:
|
152 |
+
return [self.qa[ids]]
|
153 |
+
|
154 |
+
def showQA(self, anns):
|
155 |
+
"""
|
156 |
+
Display the specified annotations.
|
157 |
+
:param anns (array of object): annotations to display
|
158 |
+
:return: None
|
159 |
+
"""
|
160 |
+
if len(anns) == 0:
|
161 |
+
return 0
|
162 |
+
for ann in anns:
|
163 |
+
quesId = ann["question_id"]
|
164 |
+
print("Question: %s" % (self.qqa[quesId]["question"]))
|
165 |
+
for ans in ann["answers"]:
|
166 |
+
print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
|
167 |
+
|
168 |
+
def loadRes(self, resFile, quesFile):
|
169 |
+
"""
|
170 |
+
Load result file and return a result object.
|
171 |
+
:param resFile (str) : file name of result file
|
172 |
+
:return: res (obj) : result api object
|
173 |
+
"""
|
174 |
+
res = VQA()
|
175 |
+
res.questions = json.load(open(quesFile))
|
176 |
+
res.dataset["info"] = copy.deepcopy(self.questions["info"])
|
177 |
+
res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
|
178 |
+
res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
|
179 |
+
res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
|
180 |
+
res.dataset["license"] = copy.deepcopy(self.questions["license"])
|
181 |
+
|
182 |
+
print("Loading and preparing results... ")
|
183 |
+
time_t = datetime.datetime.utcnow()
|
184 |
+
anns = json.load(open(resFile))
|
185 |
+
assert type(anns) == list, "results is not an array of objects"
|
186 |
+
annsQuesIds = [ann["question_id"] for ann in anns]
|
187 |
+
# print set of question ids that do not have corresponding annotations
|
188 |
+
|
189 |
+
# assert set(annsQuesIds) == set(self.getQuesIds()), \
|
190 |
+
# 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
|
191 |
+
for ann in anns:
|
192 |
+
quesId = ann["question_id"]
|
193 |
+
if res.dataset["task_type"] == "Multiple Choice":
|
194 |
+
assert (
|
195 |
+
ann["answer"] in self.qqa[quesId]["multiple_choices"]
|
196 |
+
), "predicted answer is not one of the multiple choices"
|
197 |
+
qaAnn = self.qa[quesId]
|
198 |
+
ann["image_id"] = qaAnn["image_id"]
|
199 |
+
ann["question_type"] = qaAnn["question_type"]
|
200 |
+
ann["answer_type"] = qaAnn["answer_type"]
|
201 |
+
print(
|
202 |
+
"DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
|
203 |
+
)
|
204 |
+
|
205 |
+
res.dataset["annotations"] = anns
|
206 |
+
res.createIndex()
|
207 |
+
return res
|
208 |
+
|
209 |
+
|
210 |
+
class VQAEval:
|
211 |
+
def __init__(self, vqa=None, vqaRes=None, n=2):
|
212 |
+
self.n = n
|
213 |
+
self.accuracy = {}
|
214 |
+
self.evalQA = {}
|
215 |
+
self.evalQuesType = {}
|
216 |
+
self.evalAnsType = {}
|
217 |
+
self.vqa = vqa
|
218 |
+
self.vqaRes = vqaRes
|
219 |
+
if vqaRes is not None:
|
220 |
+
self.params = {"question_id": vqaRes.getQuesIds()}
|
221 |
+
self.contractions = {
|
222 |
+
"aint": "ain't",
|
223 |
+
"arent": "aren't",
|
224 |
+
"cant": "can't",
|
225 |
+
"couldve": "could've",
|
226 |
+
"couldnt": "couldn't",
|
227 |
+
"couldn'tve": "couldn't've",
|
228 |
+
"couldnt've": "couldn't've",
|
229 |
+
"didnt": "didn't",
|
230 |
+
"doesnt": "doesn't",
|
231 |
+
"dont": "don't",
|
232 |
+
"hadnt": "hadn't",
|
233 |
+
"hadnt've": "hadn't've",
|
234 |
+
"hadn'tve": "hadn't've",
|
235 |
+
"hasnt": "hasn't",
|
236 |
+
"havent": "haven't",
|
237 |
+
"hed": "he'd",
|
238 |
+
"hed've": "he'd've",
|
239 |
+
"he'dve": "he'd've",
|
240 |
+
"hes": "he's",
|
241 |
+
"howd": "how'd",
|
242 |
+
"howll": "how'll",
|
243 |
+
"hows": "how's",
|
244 |
+
"Id've": "I'd've",
|
245 |
+
"I'dve": "I'd've",
|
246 |
+
"Im": "I'm",
|
247 |
+
"Ive": "I've",
|
248 |
+
"isnt": "isn't",
|
249 |
+
"itd": "it'd",
|
250 |
+
"itd've": "it'd've",
|
251 |
+
"it'dve": "it'd've",
|
252 |
+
"itll": "it'll",
|
253 |
+
"let's": "let's",
|
254 |
+
"maam": "ma'am",
|
255 |
+
"mightnt": "mightn't",
|
256 |
+
"mightnt've": "mightn't've",
|
257 |
+
"mightn'tve": "mightn't've",
|
258 |
+
"mightve": "might've",
|
259 |
+
"mustnt": "mustn't",
|
260 |
+
"mustve": "must've",
|
261 |
+
"neednt": "needn't",
|
262 |
+
"notve": "not've",
|
263 |
+
"oclock": "o'clock",
|
264 |
+
"oughtnt": "oughtn't",
|
265 |
+
"ow's'at": "'ow's'at",
|
266 |
+
"'ows'at": "'ow's'at",
|
267 |
+
"'ow'sat": "'ow's'at",
|
268 |
+
"shant": "shan't",
|
269 |
+
"shed've": "she'd've",
|
270 |
+
"she'dve": "she'd've",
|
271 |
+
"she's": "she's",
|
272 |
+
"shouldve": "should've",
|
273 |
+
"shouldnt": "shouldn't",
|
274 |
+
"shouldnt've": "shouldn't've",
|
275 |
+
"shouldn'tve": "shouldn't've",
|
276 |
+
"somebody'd": "somebodyd",
|
277 |
+
"somebodyd've": "somebody'd've",
|
278 |
+
"somebody'dve": "somebody'd've",
|
279 |
+
"somebodyll": "somebody'll",
|
280 |
+
"somebodys": "somebody's",
|
281 |
+
"someoned": "someone'd",
|
282 |
+
"someoned've": "someone'd've",
|
283 |
+
"someone'dve": "someone'd've",
|
284 |
+
"someonell": "someone'll",
|
285 |
+
"someones": "someone's",
|
286 |
+
"somethingd": "something'd",
|
287 |
+
"somethingd've": "something'd've",
|
288 |
+
"something'dve": "something'd've",
|
289 |
+
"somethingll": "something'll",
|
290 |
+
"thats": "that's",
|
291 |
+
"thered": "there'd",
|
292 |
+
"thered've": "there'd've",
|
293 |
+
"there'dve": "there'd've",
|
294 |
+
"therere": "there're",
|
295 |
+
"theres": "there's",
|
296 |
+
"theyd": "they'd",
|
297 |
+
"theyd've": "they'd've",
|
298 |
+
"they'dve": "they'd've",
|
299 |
+
"theyll": "they'll",
|
300 |
+
"theyre": "they're",
|
301 |
+
"theyve": "they've",
|
302 |
+
"twas": "'twas",
|
303 |
+
"wasnt": "wasn't",
|
304 |
+
"wed've": "we'd've",
|
305 |
+
"we'dve": "we'd've",
|
306 |
+
"weve": "we've",
|
307 |
+
"werent": "weren't",
|
308 |
+
"whatll": "what'll",
|
309 |
+
"whatre": "what're",
|
310 |
+
"whats": "what's",
|
311 |
+
"whatve": "what've",
|
312 |
+
"whens": "when's",
|
313 |
+
"whered": "where'd",
|
314 |
+
"wheres": "where's",
|
315 |
+
"whereve": "where've",
|
316 |
+
"whod": "who'd",
|
317 |
+
"whod've": "who'd've",
|
318 |
+
"who'dve": "who'd've",
|
319 |
+
"wholl": "who'll",
|
320 |
+
"whos": "who's",
|
321 |
+
"whove": "who've",
|
322 |
+
"whyll": "why'll",
|
323 |
+
"whyre": "why're",
|
324 |
+
"whys": "why's",
|
325 |
+
"wont": "won't",
|
326 |
+
"wouldve": "would've",
|
327 |
+
"wouldnt": "wouldn't",
|
328 |
+
"wouldnt've": "wouldn't've",
|
329 |
+
"wouldn'tve": "wouldn't've",
|
330 |
+
"yall": "y'all",
|
331 |
+
"yall'll": "y'all'll",
|
332 |
+
"y'allll": "y'all'll",
|
333 |
+
"yall'd've": "y'all'd've",
|
334 |
+
"y'alld've": "y'all'd've",
|
335 |
+
"y'all'dve": "y'all'd've",
|
336 |
+
"youd": "you'd",
|
337 |
+
"youd've": "you'd've",
|
338 |
+
"you'dve": "you'd've",
|
339 |
+
"youll": "you'll",
|
340 |
+
"youre": "you're",
|
341 |
+
"youve": "you've",
|
342 |
+
}
|
343 |
+
self.manualMap = {
|
344 |
+
"none": "0",
|
345 |
+
"zero": "0",
|
346 |
+
"one": "1",
|
347 |
+
"two": "2",
|
348 |
+
"three": "3",
|
349 |
+
"four": "4",
|
350 |
+
"five": "5",
|
351 |
+
"six": "6",
|
352 |
+
"seven": "7",
|
353 |
+
"eight": "8",
|
354 |
+
"nine": "9",
|
355 |
+
"ten": "10",
|
356 |
+
}
|
357 |
+
self.articles = ["a", "an", "the"]
|
358 |
+
|
359 |
+
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
|
360 |
+
self.commaStrip = re.compile("(\d)(\,)(\d)")
|
361 |
+
self.punct = [
|
362 |
+
";",
|
363 |
+
r"/",
|
364 |
+
"[",
|
365 |
+
"]",
|
366 |
+
'"',
|
367 |
+
"{",
|
368 |
+
"}",
|
369 |
+
"(",
|
370 |
+
")",
|
371 |
+
"=",
|
372 |
+
"+",
|
373 |
+
"\\",
|
374 |
+
"_",
|
375 |
+
"-",
|
376 |
+
">",
|
377 |
+
"<",
|
378 |
+
"@",
|
379 |
+
"`",
|
380 |
+
",",
|
381 |
+
"?",
|
382 |
+
"!",
|
383 |
+
]
|
384 |
+
|
385 |
+
def evaluate(self, quesIds=None):
|
386 |
+
if quesIds == None:
|
387 |
+
quesIds = [quesId for quesId in self.params["question_id"]]
|
388 |
+
gts = {}
|
389 |
+
res = {}
|
390 |
+
for quesId in quesIds:
|
391 |
+
gts[quesId] = self.vqa.qa[quesId]
|
392 |
+
res[quesId] = self.vqaRes.qa[quesId]
|
393 |
+
|
394 |
+
# =================================================
|
395 |
+
# Compute accuracy
|
396 |
+
# =================================================
|
397 |
+
accQA = []
|
398 |
+
accQuesType = {}
|
399 |
+
accAnsType = {}
|
400 |
+
print("computing accuracy")
|
401 |
+
step = 0
|
402 |
+
for quesId in quesIds:
|
403 |
+
for ansDic in gts[quesId]["answers"]:
|
404 |
+
ansDic["answer"] = ansDic["answer"].replace("\n", " ")
|
405 |
+
ansDic["answer"] = ansDic["answer"].replace("\t", " ")
|
406 |
+
ansDic["answer"] = ansDic["answer"].strip()
|
407 |
+
resAns = res[quesId]["answer"]
|
408 |
+
resAns = resAns.replace("\n", " ")
|
409 |
+
resAns = resAns.replace("\t", " ")
|
410 |
+
resAns = resAns.strip()
|
411 |
+
gtAcc = []
|
412 |
+
gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
|
413 |
+
|
414 |
+
if len(set(gtAnswers)) > 1:
|
415 |
+
for ansDic in gts[quesId]["answers"]:
|
416 |
+
ansDic["answer"] = self.processPunctuation(ansDic["answer"])
|
417 |
+
ansDic["answer"] = self.processDigitArticle(ansDic["answer"])
|
418 |
+
resAns = self.processPunctuation(resAns)
|
419 |
+
resAns = self.processDigitArticle(resAns)
|
420 |
+
|
421 |
+
for gtAnsDatum in gts[quesId]["answers"]:
|
422 |
+
otherGTAns = [
|
423 |
+
item for item in gts[quesId]["answers"] if item != gtAnsDatum
|
424 |
+
]
|
425 |
+
matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
|
426 |
+
acc = min(1, float(len(matchingAns)) / 3)
|
427 |
+
gtAcc.append(acc)
|
428 |
+
quesType = gts[quesId]["question_type"]
|
429 |
+
ansType = gts[quesId]["answer_type"]
|
430 |
+
avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
|
431 |
+
accQA.append(avgGTAcc)
|
432 |
+
if quesType not in accQuesType:
|
433 |
+
accQuesType[quesType] = []
|
434 |
+
accQuesType[quesType].append(avgGTAcc)
|
435 |
+
if ansType not in accAnsType:
|
436 |
+
accAnsType[ansType] = []
|
437 |
+
accAnsType[ansType].append(avgGTAcc)
|
438 |
+
self.setEvalQA(quesId, avgGTAcc)
|
439 |
+
self.setEvalQuesType(quesId, quesType, avgGTAcc)
|
440 |
+
self.setEvalAnsType(quesId, ansType, avgGTAcc)
|
441 |
+
if step % 100 == 0:
|
442 |
+
self.updateProgress(step / float(len(quesIds)))
|
443 |
+
step = step + 1
|
444 |
+
|
445 |
+
self.setAccuracy(accQA, accQuesType, accAnsType)
|
446 |
+
print("Done computing accuracy")
|
447 |
+
|
448 |
+
def processPunctuation(self, inText):
|
449 |
+
outText = inText
|
450 |
+
for p in self.punct:
|
451 |
+
if (p + " " in inText or " " + p in inText) or (
|
452 |
+
re.search(self.commaStrip, inText) != None
|
453 |
+
):
|
454 |
+
outText = outText.replace(p, "")
|
455 |
+
else:
|
456 |
+
outText = outText.replace(p, " ")
|
457 |
+
outText = self.periodStrip.sub("", outText, re.UNICODE)
|
458 |
+
return outText
|
459 |
+
|
460 |
+
def processDigitArticle(self, inText):
|
461 |
+
outText = []
|
462 |
+
tempText = inText.lower().split()
|
463 |
+
for word in tempText:
|
464 |
+
word = self.manualMap.setdefault(word, word)
|
465 |
+
if word not in self.articles:
|
466 |
+
outText.append(word)
|
467 |
+
else:
|
468 |
+
pass
|
469 |
+
for wordId, word in enumerate(outText):
|
470 |
+
if word in self.contractions:
|
471 |
+
outText[wordId] = self.contractions[word]
|
472 |
+
outText = " ".join(outText)
|
473 |
+
return outText
|
474 |
+
|
475 |
+
def setAccuracy(self, accQA, accQuesType, accAnsType):
|
476 |
+
self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
|
477 |
+
self.accuracy["perQuestionType"] = {
|
478 |
+
quesType: round(
|
479 |
+
100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
|
480 |
+
self.n,
|
481 |
+
)
|
482 |
+
for quesType in accQuesType
|
483 |
+
}
|
484 |
+
self.accuracy["perAnswerType"] = {
|
485 |
+
ansType: round(
|
486 |
+
100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
|
487 |
+
)
|
488 |
+
for ansType in accAnsType
|
489 |
+
}
|
490 |
+
|
491 |
+
def setEvalQA(self, quesId, acc):
|
492 |
+
self.evalQA[quesId] = round(100 * acc, self.n)
|
493 |
+
|
494 |
+
def setEvalQuesType(self, quesId, quesType, acc):
|
495 |
+
if quesType not in self.evalQuesType:
|
496 |
+
self.evalQuesType[quesType] = {}
|
497 |
+
self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
|
498 |
+
|
499 |
+
def setEvalAnsType(self, quesId, ansType, acc):
|
500 |
+
if ansType not in self.evalAnsType:
|
501 |
+
self.evalAnsType[ansType] = {}
|
502 |
+
self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
|
503 |
+
|
504 |
+
def updateProgress(self, progress):
|
505 |
+
barLength = 20
|
506 |
+
status = ""
|
507 |
+
if isinstance(progress, int):
|
508 |
+
progress = float(progress)
|
509 |
+
if not isinstance(progress, float):
|
510 |
+
progress = 0
|
511 |
+
status = "error: progress var must be float\r\n"
|
512 |
+
if progress < 0:
|
513 |
+
progress = 0
|
514 |
+
status = "Halt...\r\n"
|
515 |
+
if progress >= 1:
|
516 |
+
progress = 1
|
517 |
+
status = "Done...\r\n"
|
518 |
+
block = int(round(barLength * progress))
|
519 |
+
text = "\rFinshed Percent: [{0}] {1}% {2}".format(
|
520 |
+
"#" * block + "-" * (barLength - block), int(progress * 100), status
|
521 |
+
)
|
522 |
+
sys.stdout.write(text)
|
523 |
+
sys.stdout.flush()
|
524 |
+
|
525 |
+
|
526 |
+
def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_path, vqa_dataset):
|
527 |
+
"""Compute the VQA accuracy metric.
|
528 |
+
|
529 |
+
Args:
|
530 |
+
predictions (List): list of predictions
|
531 |
+
ground_truth (List[List]): list of all possible ground truth answers
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
float: VQA accuracy
|
535 |
+
"""
|
536 |
+
# coding: utf-8
|
537 |
+
# dataDir = data_dir
|
538 |
+
|
539 |
+
# set up file names and paths
|
540 |
+
# versionType = 'v2_' # this should be '' when using VQA v2.0 dataset
|
541 |
+
# 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
|
542 |
+
# taskType = 'OpenEnded'
|
543 |
+
# 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
|
544 |
+
# dataType = 'mscoco'
|
545 |
+
# dataSubType = 'train2014'
|
546 |
+
# annFile = '%s/%s%s_%s_annotations.json' % (
|
547 |
+
# dataDir, versionType, dataType, dataSubType)
|
548 |
+
# quesFile = '%s/%s%s_%s_%s_questions.json' % (
|
549 |
+
# dataDir, versionType, taskType, dataType, dataSubType)
|
550 |
+
# imgDir = '%s/%s/%s/' % (dataDir, dataType, dataSubType)
|
551 |
+
# resultType = res_file_name
|
552 |
+
# fileTypes = ['results', 'accuracy',
|
553 |
+
# 'evalQA', 'evalQuesType', 'evalAnsType']
|
554 |
+
|
555 |
+
# An example result json file has been provided in './Results' folder.
|
556 |
+
|
557 |
+
# [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/%s%s_%s_%s_%s_%s.json' % (dataDir, versionType, taskType, dataType, dataSubType,
|
558 |
+
# resultType, fileType) for fileType in fileTypes]
|
559 |
+
|
560 |
+
# create vqa object and vqaRes object
|
561 |
+
vqa = VQA(annotation_json_path, question_json_path)
|
562 |
+
vqaRes = vqa.loadRes(result_json_path, question_json_path)
|
563 |
+
|
564 |
+
# create vqaEval object by taking vqa and vqaRes
|
565 |
+
# n is precision of accuracy (number of places after decimal), default is 2
|
566 |
+
vqaEval = VQAEval(vqa, vqaRes, n=2)
|
567 |
+
|
568 |
+
# evaluate results
|
569 |
+
"""
|
570 |
+
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
|
571 |
+
By default it uses all the question ids in annotation file
|
572 |
+
"""
|
573 |
+
vqaEval.evaluate()
|
574 |
+
|
575 |
+
return vqaEval.accuracy["overall"]
|
576 |
+
|
577 |
+
|
578 |
+
def postprocess_vqa_generation(predictions):
|
579 |
+
return re.split("Question|Answer", predictions, 1)[0]
|
580 |
+
|
581 |
+
|
582 |
+
def compute_gqa_accuracy(results):
|
583 |
+
acc = []
|
584 |
+
vqa_tool = VQAEval()
|
585 |
+
|
586 |
+
for res in results:
|
587 |
+
gt_ans = res["answers"]
|
588 |
+
pred = res["answer"]
|
589 |
+
pred = vqa_tool.processPunctuation(pred)
|
590 |
+
pred = vqa_tool.processDigitArticle(pred)
|
591 |
+
vqa_acc = 1 if pred == gt_ans else 0
|
592 |
+
acc.append(vqa_acc)
|
593 |
+
accuracy = sum(acc) / len(acc)
|
594 |
+
return accuracy
|
multimodal/build/lib/open_flamingo/src/__init__.py
ADDED
File without changes
|
multimodal/build/lib/open_flamingo/src/attention.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import init
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class SEAttention(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, channel=512,reduction=16):
|
11 |
+
super().__init__()
|
12 |
+
self.fc = nn.Sequential(
|
13 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
14 |
+
nn.GELU(),
|
15 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
16 |
+
nn.GELU(),
|
17 |
+
nn.Linear(channel, 1, bias=False),
|
18 |
+
nn.Sigmoid()
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
def init_weights(self):
|
23 |
+
for m in self.modules():
|
24 |
+
if isinstance(m, nn.Conv2d):
|
25 |
+
init.kaiming_normal_(m.weight, mode='fan_out')
|
26 |
+
if m.bias is not None:
|
27 |
+
init.constant_(m.bias, 0)
|
28 |
+
elif isinstance(m, nn.BatchNorm2d):
|
29 |
+
init.constant_(m.weight, 1)
|
30 |
+
init.constant_(m.bias, 0)
|
31 |
+
elif isinstance(m, nn.Linear):
|
32 |
+
init.normal_(m.weight, std=0.001)
|
33 |
+
if m.bias is not None:
|
34 |
+
init.constant_(m.bias, 0)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = self.fc(x)
|
38 |
+
return x
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == '__main__':
|
42 |
+
input=torch.randn(50,512,7,7)
|
43 |
+
se = SEAttention(channel=512,reduction=8)
|
44 |
+
output=se(input)
|
45 |
+
print(output.shape)
|
multimodal/build/lib/open_flamingo/src/factory.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
2 |
+
import open_clip
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .flamingo import Flamingo
|
6 |
+
from .flamingo_lm import FlamingoLMMixin
|
7 |
+
from .utils import extend_instance
|
8 |
+
import logging
|
9 |
+
import random
|
10 |
+
import time
|
11 |
+
|
12 |
+
def create_model_and_transforms(
|
13 |
+
clip_vision_encoder_path: str,
|
14 |
+
clip_vision_encoder_pretrained: str,
|
15 |
+
lang_encoder_path: str,
|
16 |
+
tokenizer_path: str,
|
17 |
+
use_local_files: bool = False,
|
18 |
+
decoder_layers_attr_name: str = None,
|
19 |
+
location_token_num: int = 1000,
|
20 |
+
checkpoint_activations: bool = False,
|
21 |
+
freeze_vision_encoder: bool = False,
|
22 |
+
lora: bool = False,
|
23 |
+
lora_r: int = 16,
|
24 |
+
fix_ffn: bool = False,
|
25 |
+
add_visual_token: bool = False,
|
26 |
+
add_box: bool = False,
|
27 |
+
add_pe: bool = False,
|
28 |
+
add_relation: bool = False,
|
29 |
+
use_format_v2: bool = False,
|
30 |
+
use_sam: str = None,
|
31 |
+
enhance_data: bool = False,
|
32 |
+
roi_align: bool = False,
|
33 |
+
roi_output_size: int = 4,
|
34 |
+
apply_mask: bool = False,
|
35 |
+
**flamingo_kwargs,
|
36 |
+
):
|
37 |
+
"""
|
38 |
+
Initialize a Flamingo model from a pretrained vision encoder and language encoder.
|
39 |
+
Appends special tokens to the tokenizer and freezes backbones.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
|
43 |
+
clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
|
44 |
+
lang_encoder_path (str): path to pretrained language encoder
|
45 |
+
tokenizer_path (str): path to pretrained tokenizer
|
46 |
+
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
|
47 |
+
use_local_files (bool, optional): whether to use local files. Defaults to False.
|
48 |
+
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
|
49 |
+
Returns:
|
50 |
+
Flamingo: Flamingo model from pretrained vision and language encoders
|
51 |
+
Image processor: Pipeline to preprocess input images
|
52 |
+
Tokenizer: A tokenizer for the language model
|
53 |
+
"""
|
54 |
+
if use_sam is None:
|
55 |
+
no_success = True
|
56 |
+
while no_success:
|
57 |
+
try:
|
58 |
+
vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
|
59 |
+
clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
|
60 |
+
)
|
61 |
+
no_success = False
|
62 |
+
except:
|
63 |
+
logging.info("retry creating vision_encoder")
|
64 |
+
time.sleep(random.random() * 5)
|
65 |
+
|
66 |
+
# set the vision encoder to output the visual features
|
67 |
+
vision_encoder.visual.output_tokens = True
|
68 |
+
# delete text encoder part
|
69 |
+
del vision_encoder.transformer
|
70 |
+
del vision_encoder.text_projection
|
71 |
+
del vision_encoder.token_embedding
|
72 |
+
del vision_encoder.ln_final
|
73 |
+
del vision_encoder.positional_embedding
|
74 |
+
del vision_encoder.logit_scale
|
75 |
+
vision_encoder.visual.proj = None
|
76 |
+
vision_encoder.visual.ln_post = torch.nn.Identity()
|
77 |
+
else:
|
78 |
+
from segment_anything import SamPredictor, sam_model_registry
|
79 |
+
assert use_sam == "vit_l"
|
80 |
+
sam = sam_model_registry[use_sam](checkpoint="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_256x256.pth")
|
81 |
+
del sam.prompt_encoder
|
82 |
+
del sam.mask_decoder
|
83 |
+
sam.image_encoder.neck = torch.nn.Identity()
|
84 |
+
vision_encoder = sam.image_encoder
|
85 |
+
from open_clip.transform import image_transform
|
86 |
+
image_processor = image_transform(
|
87 |
+
256,
|
88 |
+
is_train=False,
|
89 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
90 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
91 |
+
)
|
92 |
+
|
93 |
+
text_tokenizer = AutoTokenizer.from_pretrained(
|
94 |
+
tokenizer_path, local_files_only=use_local_files
|
95 |
+
)
|
96 |
+
# add Flamingo special tokens to the tokenizer
|
97 |
+
additional_special_tokens = ["<|#image#|>", "<|#endofimage#|>"]
|
98 |
+
if add_visual_token:
|
99 |
+
additional_special_tokens += ["<|#visual#|>", "<|#object#|>"]
|
100 |
+
if add_box:
|
101 |
+
additional_special_tokens += ["<|#box#|>", "<|#endofobject#|>", "<|#attr#|>", "<|#endofattr#|>"]
|
102 |
+
if use_format_v2:
|
103 |
+
additional_special_tokens += ["<|#previsual#|>", "<|#prebox#|>"]
|
104 |
+
if enhance_data:
|
105 |
+
additional_special_tokens += ["<|#NOTHING#|>"]
|
106 |
+
text_tokenizer.add_special_tokens(
|
107 |
+
{"additional_special_tokens": additional_special_tokens}
|
108 |
+
)
|
109 |
+
if text_tokenizer.pad_token is None:
|
110 |
+
# Issue: GPT models don't have a pad token, which we use to
|
111 |
+
# modify labels for the loss.
|
112 |
+
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
113 |
+
|
114 |
+
lang_encoder = AutoModelForCausalLM.from_pretrained(
|
115 |
+
lang_encoder_path, local_files_only=use_local_files
|
116 |
+
)
|
117 |
+
extend_instance(lang_encoder, FlamingoLMMixin)
|
118 |
+
|
119 |
+
if decoder_layers_attr_name is None:
|
120 |
+
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
|
121 |
+
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
|
122 |
+
lang_encoder.resize_token_embeddings(len(text_tokenizer))
|
123 |
+
lang_encoder_name = lang_encoder.__class__.__name__.lower()
|
124 |
+
if checkpoint_activations:
|
125 |
+
from fairscale.nn.checkpoint import checkpoint_wrapper
|
126 |
+
if use_sam is None:
|
127 |
+
for i in range(len(vision_encoder.visual.transformer.resblocks)):
|
128 |
+
vision_encoder.visual.transformer.resblocks[i] = checkpoint_wrapper(
|
129 |
+
vision_encoder.visual.transformer.resblocks[i],
|
130 |
+
offload_to_cpu=False,
|
131 |
+
)
|
132 |
+
else:
|
133 |
+
for i in range(len(vision_encoder.blocks)):
|
134 |
+
vision_encoder.blocks[i] = checkpoint_wrapper(
|
135 |
+
vision_encoder.blocks[i],
|
136 |
+
offload_to_cpu=False,
|
137 |
+
)
|
138 |
+
if "opt" in lang_encoder_name:
|
139 |
+
for i in range(len(lang_encoder.model.decoder.layers)):
|
140 |
+
lang_encoder.model.decoder.layers[i] = checkpoint_wrapper(
|
141 |
+
lang_encoder.model.decoder.layers[i],
|
142 |
+
offload_to_cpu=False,
|
143 |
+
)
|
144 |
+
elif "codegen" in lang_encoder_name:
|
145 |
+
for i in range(len(lang_encoder.transformer.h)):
|
146 |
+
lang_encoder.transformer.h[i] = checkpoint_wrapper(
|
147 |
+
lang_encoder.transformer.h[i],
|
148 |
+
offload_to_cpu=False,
|
149 |
+
)
|
150 |
+
elif "llama" in lang_encoder_name:
|
151 |
+
for i in range(len(lang_encoder.model.layers)):
|
152 |
+
lang_encoder.model.layers[i] = checkpoint_wrapper(
|
153 |
+
lang_encoder.model.layers[i],
|
154 |
+
offload_to_cpu=False,
|
155 |
+
)
|
156 |
+
elif "gptneo" in lang_encoder_name:
|
157 |
+
for i in range(len(lang_encoder.gpt_neox.layers)):
|
158 |
+
lang_encoder.gpt_neox.layers[i] = checkpoint_wrapper(
|
159 |
+
lang_encoder.gpt_neox.layers[i],
|
160 |
+
offload_to_cpu=False,
|
161 |
+
)
|
162 |
+
else:
|
163 |
+
raise ValueError(f"unknown model {lang_encoder_name}")
|
164 |
+
if use_sam is None:
|
165 |
+
vis_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"]
|
166 |
+
image_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["image_size"]
|
167 |
+
patch_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["patch_size"]
|
168 |
+
else:
|
169 |
+
# SAM config
|
170 |
+
vis_dim = 1024
|
171 |
+
image_size = 256
|
172 |
+
patch_size = 16
|
173 |
+
assert image_size % patch_size == 0
|
174 |
+
vis_embed_size = (image_size // patch_size) ** 2
|
175 |
+
|
176 |
+
if lora:
|
177 |
+
from peft import LoraConfig, TaskType
|
178 |
+
from peft import get_peft_model
|
179 |
+
if "codegen" in lang_encoder_name:
|
180 |
+
lang_target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"]
|
181 |
+
elif "opt" in lang_encoder_name:
|
182 |
+
lang_target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
|
183 |
+
elif "llama" in lang_encoder_name:
|
184 |
+
lang_target_modules = ["k_proj", "v_proj", "q_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
|
185 |
+
else:
|
186 |
+
raise NotImplementedError
|
187 |
+
lang_peft_config = LoraConfig(
|
188 |
+
task_type="CAUSAL_LM",
|
189 |
+
r=16, lora_alpha=16,
|
190 |
+
target_modules=lang_target_modules,
|
191 |
+
lora_dropout=0.05, bias="none",
|
192 |
+
)
|
193 |
+
lang_encoder = get_peft_model(lang_encoder, lang_peft_config)
|
194 |
+
lang_encoder.print_trainable_parameters()
|
195 |
+
|
196 |
+
if fix_ffn:
|
197 |
+
if "opt" in lang_encoder_name:
|
198 |
+
for i in range(len(lang_encoder.model.decoder.layers)):
|
199 |
+
lang_encoder.model.decoder.layers[i].requires_grad_(False)
|
200 |
+
lang_encoder.model.decoder.layers[i].self_attn.requires_grad_(True)
|
201 |
+
else:
|
202 |
+
raise NotImplementedError
|
203 |
+
|
204 |
+
lang_dim = int(lang_encoder.config.hidden_size) if not lora else int(lang_encoder.base_model.model.config.hidden_size)
|
205 |
+
if hasattr(lang_encoder.config, "word_embed_proj_dim"):
|
206 |
+
hidden_state_dim = lang_encoder.config.word_embed_proj_dim
|
207 |
+
else:
|
208 |
+
hidden_state_dim = lang_encoder.config.hidden_size
|
209 |
+
model = Flamingo(
|
210 |
+
vision_encoder=vision_encoder,
|
211 |
+
lang_encoder=lang_encoder,
|
212 |
+
eoc_token_id=text_tokenizer.encode(text_tokenizer.eos_token)[-1],
|
213 |
+
media_token_id=text_tokenizer.encode("<|#image#|>")[-1],
|
214 |
+
image_end_token_id=text_tokenizer.encode("<|#endofimage#|>")[-1],
|
215 |
+
visual_token_id=text_tokenizer.encode("<|#visual#|>")[-1] if add_visual_token else None,
|
216 |
+
previsual_token_id=text_tokenizer.encode("<|#previsual#|>")[-1] if add_visual_token else None,
|
217 |
+
box_token_id=text_tokenizer.encode("<|#box#|>")[-1] if add_box else None,
|
218 |
+
prebox_token_id=text_tokenizer.encode("<|#prebox#|>")[-1] if add_box else None,
|
219 |
+
nothing_token_id=text_tokenizer.encode("<|#NOTHING#|>")[-1] if enhance_data else None,
|
220 |
+
endofobject_token_id=text_tokenizer.encode("<|#endofobject#|>")[-1],
|
221 |
+
vis_dim=vis_dim,
|
222 |
+
vis_embed_size=vis_embed_size,
|
223 |
+
lang_dim=lang_dim,
|
224 |
+
image_size=image_size,
|
225 |
+
patch_size=patch_size,
|
226 |
+
hidden_state_dim=hidden_state_dim,
|
227 |
+
add_visual_token=add_visual_token,
|
228 |
+
add_pe=add_pe,
|
229 |
+
add_relation=add_relation,
|
230 |
+
use_format_v2=use_format_v2,
|
231 |
+
roi_align=roi_align,
|
232 |
+
roi_output_size=roi_output_size,
|
233 |
+
apply_mask=apply_mask,
|
234 |
+
**flamingo_kwargs,
|
235 |
+
)
|
236 |
+
|
237 |
+
if freeze_vision_encoder:
|
238 |
+
print("freeze vision encoder")
|
239 |
+
model.vision_encoder.requires_grad_(False)
|
240 |
+
|
241 |
+
print(
|
242 |
+
f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
|
243 |
+
)
|
244 |
+
|
245 |
+
return model, image_processor, text_tokenizer, vis_embed_size
|
246 |
+
|
247 |
+
|
248 |
+
def _infer_decoder_layers_attr_name(model):
|
249 |
+
for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
|
250 |
+
if k.lower() in model.__class__.__name__.lower():
|
251 |
+
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
|
252 |
+
|
253 |
+
raise ValueError(
|
254 |
+
f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
|
255 |
+
)
|
256 |
+
|
257 |
+
|
258 |
+
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
|
259 |
+
"opt": "model.decoder.layers",
|
260 |
+
# "gptneo": "transformer.h",
|
261 |
+
"gptj": "transformer.h",
|
262 |
+
"gpt-j": "transformer.h",
|
263 |
+
"pythia": "gpt_neox.layers",
|
264 |
+
"gptneox": "gpt_neox.layers",
|
265 |
+
"llama": "model.layers",
|
266 |
+
"llamaforcausallm": "model.layers",
|
267 |
+
"gpt2": "transformer.h",
|
268 |
+
"codegen": "transformer.h",
|
269 |
+
}
|
multimodal/build/lib/open_flamingo/src/flamingo.py
ADDED
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
from einops import rearrange
|
4 |
+
from torch import nn
|
5 |
+
from yolox.models.yolo_head import YOLOXHead
|
6 |
+
from yolox.utils.boxes import xyxy2cxcywh, cxcywh2xyxy
|
7 |
+
from yolox.utils.demo_utils import nms
|
8 |
+
# import matplotlib.pyplot as plt
|
9 |
+
# import seaborn as sns
|
10 |
+
import numpy as np
|
11 |
+
import logging
|
12 |
+
from open_flamingo.src.gcn import GCN
|
13 |
+
from transformers import LogitsProcessorList
|
14 |
+
logging.basicConfig(
|
15 |
+
level=logging.INFO,
|
16 |
+
format='%(asctime)s %(message)s',
|
17 |
+
datefmt='%m/%d %I:%M:%S',
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
# class PositionEncodingModule(nn.Module):
|
22 |
+
# def __init__(self, dim, pos_dim=128):
|
23 |
+
# super().__init__()
|
24 |
+
# self.encode = nn.Sequential(
|
25 |
+
# nn.Linear(5, pos_dim // 2),
|
26 |
+
# nn.BatchNorm1d(pos_dim // 2),
|
27 |
+
# nn.GELU(),
|
28 |
+
# nn.Linear(pos_dim // 2, pos_dim),
|
29 |
+
# nn.BatchNorm1d(pos_dim),
|
30 |
+
# nn.GELU(),
|
31 |
+
# )
|
32 |
+
# self.merge = nn.Sequential(
|
33 |
+
# nn.Linear(dim + pos_dim, dim),
|
34 |
+
# nn.BatchNorm1d(dim),
|
35 |
+
# nn.GELU(),
|
36 |
+
# )
|
37 |
+
|
38 |
+
# def forward(self, x, box):
|
39 |
+
# box = self.encode(box)
|
40 |
+
# x = torch.cat([x, box], dim=-1)
|
41 |
+
# x = self.merge(x)
|
42 |
+
# return x
|
43 |
+
|
44 |
+
|
45 |
+
# class PositionEncodingModule(nn.Module):
|
46 |
+
# def __init__(self, dim):
|
47 |
+
# super().__init__()
|
48 |
+
# self.encode = nn.Sequential(
|
49 |
+
# nn.Linear(5, dim),
|
50 |
+
# nn.GELU(),
|
51 |
+
# )
|
52 |
+
|
53 |
+
# def forward(self, x, box):
|
54 |
+
# box = self.encode(box)
|
55 |
+
# x = x + box
|
56 |
+
# return x
|
57 |
+
|
58 |
+
|
59 |
+
# class PositionEncodingModule2(nn.Module):
|
60 |
+
# def __init__(self, dim):
|
61 |
+
# super().__init__()
|
62 |
+
# self.encode = nn.Sequential(
|
63 |
+
# nn.Linear(5 + dim, dim),
|
64 |
+
# nn.ELU(),
|
65 |
+
# )
|
66 |
+
|
67 |
+
# def forward(self, x, box):
|
68 |
+
# x = torch.cat([x, box], dim=-1)
|
69 |
+
# x = self.encode(x)
|
70 |
+
# return x
|
71 |
+
|
72 |
+
|
73 |
+
# class RelationHead(nn.Module):
|
74 |
+
# def __init__(self, dim):
|
75 |
+
# super().__init__()
|
76 |
+
# self.encode = nn.Sequential(
|
77 |
+
# nn.LayerNorm(dim),
|
78 |
+
# nn.Linear(dim, 128),
|
79 |
+
# nn.ELU(),
|
80 |
+
# )
|
81 |
+
# self.classifier = nn.Linear(256, 51)
|
82 |
+
|
83 |
+
# def forward(self, x1, x2):
|
84 |
+
# x1 = self.encode(x1)
|
85 |
+
# x2 = self.encode(x2)
|
86 |
+
# x = torch.cat([x1, x2], dim=-1)
|
87 |
+
# x = self.classifier(x)
|
88 |
+
# return x
|
89 |
+
|
90 |
+
|
91 |
+
class Flamingo(nn.Module):
|
92 |
+
def __init__(
|
93 |
+
self,
|
94 |
+
vision_encoder: nn.Module,
|
95 |
+
lang_encoder: nn.Module,
|
96 |
+
eoc_token_id: int,
|
97 |
+
media_token_id: int,
|
98 |
+
image_end_token_id: int,
|
99 |
+
visual_token_id: int,
|
100 |
+
previsual_token_id: int,
|
101 |
+
box_token_id: int,
|
102 |
+
prebox_token_id: int,
|
103 |
+
nothing_token_id: int,
|
104 |
+
endofobject_token_id: int,
|
105 |
+
vis_dim: int,
|
106 |
+
vis_embed_size: int,
|
107 |
+
lang_dim: int,
|
108 |
+
hidden_state_dim: int,
|
109 |
+
image_size: int,
|
110 |
+
patch_size: int,
|
111 |
+
use_media_placement_augmentation: bool = False,
|
112 |
+
add_visual_token: bool = False,
|
113 |
+
add_pe: bool = False,
|
114 |
+
add_relation: bool = False,
|
115 |
+
use_format_v2: bool = False,
|
116 |
+
roi_align: bool = False,
|
117 |
+
roi_output_size: int = 4,
|
118 |
+
apply_mask: bool = False,
|
119 |
+
):
|
120 |
+
"""
|
121 |
+
Args:
|
122 |
+
vision_encoder (nn.Module): HF CLIPModel
|
123 |
+
lang_encoder (nn.Module): HF causal language model
|
124 |
+
eoc_token_id (int): Token id for eos token
|
125 |
+
media_token_id (int): Token id for <|#image#|>
|
126 |
+
vis_dim (int): Dimension of the visual features.
|
127 |
+
Visual features are projected to match this shape along the last dimension.
|
128 |
+
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
|
129 |
+
use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False.
|
130 |
+
"""
|
131 |
+
super().__init__()
|
132 |
+
self.image_end_token_id = image_end_token_id
|
133 |
+
self.eoc_token_id = eoc_token_id
|
134 |
+
self.media_token_id = media_token_id
|
135 |
+
self.use_media_placement_augmentation = use_media_placement_augmentation
|
136 |
+
self.vis_dim = vis_dim
|
137 |
+
self.lang_dim = lang_dim
|
138 |
+
# inner_dim = self.lang_dim * 4
|
139 |
+
# self.vis_proj = nn.Sequential(
|
140 |
+
# nn.LayerNorm(self.vis_dim),
|
141 |
+
# nn.Linear(self.vis_dim, inner_dim, bias=False),
|
142 |
+
# nn.GELU(),
|
143 |
+
# nn.Linear(inner_dim, self.lang_dim, bias=False),
|
144 |
+
# )
|
145 |
+
self.vis_proj = nn.Linear(self.vis_dim, self.lang_dim)
|
146 |
+
self.vision_encoder = vision_encoder
|
147 |
+
self.num_positions = vis_embed_size
|
148 |
+
self.lang_encoder = lang_encoder
|
149 |
+
self.lang_encoder.init_flamingo(
|
150 |
+
media_token_id=media_token_id,
|
151 |
+
use_media_placement_augmentation=self.use_media_placement_augmentation,
|
152 |
+
)
|
153 |
+
first_layer = self.lang_encoder._get_decoder_layers()[0]
|
154 |
+
first_layer.add_visual_token = add_visual_token
|
155 |
+
first_layer.visual_token_id = visual_token_id
|
156 |
+
first_layer.media_token_id = media_token_id
|
157 |
+
first_layer.box_token_id = box_token_id
|
158 |
+
# first_layer.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
|
159 |
+
# assert not (add_pe and add_relation)
|
160 |
+
# self.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
|
161 |
+
# first_layer.pos_enc = self.pos_enc
|
162 |
+
self.box_token_id = box_token_id
|
163 |
+
self.prebox_token_id = prebox_token_id
|
164 |
+
self.media_token_id = media_token_id
|
165 |
+
self.visual_token_id = visual_token_id
|
166 |
+
self.previsual_token_id = previsual_token_id
|
167 |
+
self.hidden_state_dim = hidden_state_dim
|
168 |
+
self.image_size = image_size
|
169 |
+
self.patch_size = patch_size
|
170 |
+
self.patch_num = self.image_size // self.patch_size
|
171 |
+
self.detection_head = YOLOXHead(
|
172 |
+
num_classes=1,
|
173 |
+
strides=[patch_size],
|
174 |
+
in_channels=[self.hidden_state_dim + self.lang_dim],
|
175 |
+
)
|
176 |
+
self.use_format_v2 = use_format_v2
|
177 |
+
self.nothing_token_id = nothing_token_id
|
178 |
+
self.roi_align = roi_align
|
179 |
+
self.roi_output_size = roi_output_size if roi_align else None
|
180 |
+
self.apply_mask = apply_mask
|
181 |
+
self.endofobject_token_id = endofobject_token_id
|
182 |
+
|
183 |
+
|
184 |
+
def _get_detection_batch(
|
185 |
+
self,
|
186 |
+
visual_token_id,
|
187 |
+
previsual_token_id,
|
188 |
+
input_ids: torch.Tensor,
|
189 |
+
hidden_states: torch.Tensor,
|
190 |
+
added_bbox_list,
|
191 |
+
box_num = 100,
|
192 |
+
):
|
193 |
+
select_mask = torch.logical_or(input_ids == visual_token_id, input_ids == previsual_token_id)
|
194 |
+
visual_token_position = select_mask.nonzero()
|
195 |
+
visual_token_hidden_states = hidden_states[select_mask]
|
196 |
+
prev_batch_idx = -1
|
197 |
+
media_idx = []
|
198 |
+
cnt = 0
|
199 |
+
assert len(visual_token_hidden_states) == len(visual_token_position)
|
200 |
+
if len(added_bbox_list) != len(visual_token_position):
|
201 |
+
msg = f"ERROR: {len(added_bbox_list)}:{len(visual_token_position)}\n{added_bbox_list}\n{visual_token_position}"
|
202 |
+
logging.info(msg)
|
203 |
+
alpha = 0.0
|
204 |
+
else:
|
205 |
+
alpha = 1.0
|
206 |
+
visual_batches = []
|
207 |
+
previsual_batches = []
|
208 |
+
for (batch_idx, idx), visual_token_hidden_state, bbox in zip(
|
209 |
+
visual_token_position, visual_token_hidden_states, added_bbox_list,
|
210 |
+
):
|
211 |
+
# ! VERY IMPORTANT BUG !
|
212 |
+
bbox = bbox.clone()
|
213 |
+
# ! VERY IMPORTANT BUG !
|
214 |
+
batch_idx = batch_idx.item()
|
215 |
+
idx = idx.item()
|
216 |
+
if batch_idx != prev_batch_idx:
|
217 |
+
prev_batch_idx = batch_idx
|
218 |
+
this_input_ids = input_ids[batch_idx]
|
219 |
+
cnt += len(media_idx)
|
220 |
+
media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
|
221 |
+
for i in range(len(media_idx)):
|
222 |
+
if i == len(media_idx) - 1 or idx > media_idx[i] and idx < media_idx[i+1]:
|
223 |
+
break
|
224 |
+
image_index = cnt + i
|
225 |
+
size = int(self.image_embedding[image_index].shape[0] ** 0.5)
|
226 |
+
image_embedding = self.image_embedding[image_index]
|
227 |
+
# inplace xyxy2cxcywh
|
228 |
+
# print(bbox)
|
229 |
+
# TODO: CHECK self.image_size. Is it 224?
|
230 |
+
bbox = xyxy2cxcywh(bbox) * self.image_size
|
231 |
+
# print(bbox)
|
232 |
+
concat_image_visual_embedding = torch.cat([image_embedding, visual_token_hidden_state.unsqueeze(0).repeat(image_embedding.shape[0], 1)], dim=-1).reshape(size, size, -1)
|
233 |
+
label = torch.cat([torch.zeros(bbox.shape[0], 1, device=bbox.device), bbox], dim=-1)
|
234 |
+
label = torch.cat([label, torch.zeros(box_num - label.shape[0], label.shape[1], device=label.device)], dim=0)
|
235 |
+
if input_ids[batch_idx, idx] == previsual_token_id:
|
236 |
+
previsual_batches.append([concat_image_visual_embedding, label])
|
237 |
+
elif input_ids[batch_idx, idx] == visual_token_id:
|
238 |
+
visual_batches.append([concat_image_visual_embedding, label])
|
239 |
+
else:
|
240 |
+
logging.info(f"WARNING... NOT visual nor previsual. it is {input_ids[batch_idx, idx]}")
|
241 |
+
return visual_batches, previsual_batches, alpha, alpha
|
242 |
+
|
243 |
+
def get_detection_losses(
|
244 |
+
self,
|
245 |
+
input_ids: torch.Tensor,
|
246 |
+
hidden_states: torch.Tensor,
|
247 |
+
added_bbox_list,
|
248 |
+
box_num = 100,
|
249 |
+
):
|
250 |
+
visual_token_batches, previsual_token_batches, alpha1, alpha2 = self._get_detection_batch(
|
251 |
+
visual_token_id=self.visual_token_id,
|
252 |
+
previsual_token_id=self.previsual_token_id,
|
253 |
+
input_ids=input_ids,
|
254 |
+
hidden_states=hidden_states,
|
255 |
+
added_bbox_list=added_bbox_list,
|
256 |
+
box_num=box_num,
|
257 |
+
)
|
258 |
+
loss_dict = []
|
259 |
+
for batches, alpha in zip([visual_token_batches, previsual_token_batches], [alpha1, alpha2]):
|
260 |
+
# x: [B, C, H, W]
|
261 |
+
if len(batches) != 0:
|
262 |
+
x = torch.cat([batch[0].unsqueeze(0) for batch in batches], dim=0).permute(0,3,1,2)
|
263 |
+
labels = torch.cat([batch[1].unsqueeze(0) for batch in batches], dim=0)
|
264 |
+
else:
|
265 |
+
x = None
|
266 |
+
labels = None
|
267 |
+
if x is not None:
|
268 |
+
losses = self.detection_head(xin=[x], labels=labels)
|
269 |
+
loss, loss_iou, loss_obj, loss_cls, loss_l1, _ = losses
|
270 |
+
else:
|
271 |
+
loss = torch.tensor(0.0).cuda()
|
272 |
+
loss_iou = loss
|
273 |
+
loss_obj = loss
|
274 |
+
loss_cls = loss
|
275 |
+
loss_l1 = loss
|
276 |
+
|
277 |
+
loss_dict.append(dict(
|
278 |
+
loss=loss * alpha,
|
279 |
+
loss_iou=loss_iou * alpha,
|
280 |
+
loss_obj=loss_obj * alpha,
|
281 |
+
loss_cls=loss_cls * alpha,
|
282 |
+
loss_l1=loss_l1 * alpha,
|
283 |
+
))
|
284 |
+
ret_loss = {}
|
285 |
+
for key in loss_dict[0].keys():
|
286 |
+
ret_loss[key] = 0.0
|
287 |
+
for d in loss_dict:
|
288 |
+
ret_loss[key] += d[key]
|
289 |
+
return ret_loss, loss_dict
|
290 |
+
|
291 |
+
def get_detection_result(
|
292 |
+
self,
|
293 |
+
input_ids: torch.Tensor,
|
294 |
+
hidden_states: torch.Tensor,
|
295 |
+
nms_thr: float = 0.45,
|
296 |
+
score_thr: float = 0.01,
|
297 |
+
debug_id: int = 0,
|
298 |
+
debug_mode: bool = False,
|
299 |
+
):
|
300 |
+
assert len(input_ids) == 1, "only batch size = 1 is supported yet"
|
301 |
+
# assert len(self.image_embedding) == 1, "only one image is supported yet"
|
302 |
+
# assert (input_ids[..., -1] == self.visual_token_id).all(), "the last token should be visual token"
|
303 |
+
visual_token_hidden_state = hidden_states[..., -1, :]
|
304 |
+
boxes_list = []
|
305 |
+
scores_list = []
|
306 |
+
for image_embedding in self.image_embedding:
|
307 |
+
size = int(image_embedding.shape[0] ** 0.5)
|
308 |
+
x = torch.cat([image_embedding, visual_token_hidden_state.repeat(image_embedding.shape[0], 1)], dim=-1).reshape(size, size, -1).unsqueeze(0).permute(0,3,1,2)
|
309 |
+
with torch.no_grad():
|
310 |
+
outputs = self.detection_head(xin=[x], labels=None)
|
311 |
+
boxes = outputs[0,:,:4].cpu().numpy()
|
312 |
+
scores = outputs[0,:,4].cpu().numpy()
|
313 |
+
scores_mask = scores > score_thr
|
314 |
+
boxes = boxes[scores_mask]
|
315 |
+
boxes = cxcywh2xyxy(boxes)
|
316 |
+
scores = scores[scores_mask]
|
317 |
+
keep = nms(boxes, scores, nms_thr=nms_thr)
|
318 |
+
boxes = boxes[keep]
|
319 |
+
scores = scores[keep]
|
320 |
+
if debug_mode:
|
321 |
+
obj_heatmap = outputs[0,:, -2].reshape(size, size).cpu().numpy()
|
322 |
+
import matplotlib.pyplot as plt
|
323 |
+
import seaborn as sns
|
324 |
+
plt.figure()
|
325 |
+
sns_plot = sns.heatmap(obj_heatmap)
|
326 |
+
plt.savefig(f"heatmap_{debug_id}.jpg")
|
327 |
+
debug_id += 1
|
328 |
+
boxes_list.append(boxes)
|
329 |
+
scores_list.append(scores)
|
330 |
+
if len(boxes_list) == 1:
|
331 |
+
boxes_list = boxes_list[0]
|
332 |
+
scores_list = scores_list[0]
|
333 |
+
return boxes_list, scores_list
|
334 |
+
|
335 |
+
def _condition_attention(self, loc_list = None):
|
336 |
+
for i in range(len(self.lang_encoder.gpt_neox.layers)):
|
337 |
+
self.lang_encoder.gpt_neox.layers[i].decoder_layer.attention.loc_list = loc_list
|
338 |
+
|
339 |
+
def forward(
|
340 |
+
self,
|
341 |
+
vision_x: torch.Tensor,
|
342 |
+
lang_x: torch.Tensor,
|
343 |
+
attention_mask: torch.Tensor = None,
|
344 |
+
labels: torch.Tensor = None,
|
345 |
+
use_cached_vision_x: bool = False,
|
346 |
+
clear_conditioned_layers: bool = True,
|
347 |
+
past_key_values=None,
|
348 |
+
use_cache: bool = False,
|
349 |
+
image_nums=None,
|
350 |
+
image_start_index_list=None,
|
351 |
+
added_bbox_list=None,
|
352 |
+
add_box: bool = False,
|
353 |
+
relations=None,
|
354 |
+
debug_mode: bool = False,
|
355 |
+
):
|
356 |
+
"""
|
357 |
+
Forward pass of Flamingo.
|
358 |
+
|
359 |
+
Args:
|
360 |
+
vision_x (torch.Tensor): Vision input
|
361 |
+
shape (B, T_img, F, C, H, W) with F=1
|
362 |
+
lang_x (torch.Tensor): Language input ids
|
363 |
+
shape (B, T_txt)
|
364 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
365 |
+
labels (torch.Tensor, optional): Labels. Defaults to None.
|
366 |
+
clear_conditioned_layers: if True, clear the conditioned layers
|
367 |
+
once the foward pass is completed. Set this to false if the
|
368 |
+
same set of images will be reused in another subsequent
|
369 |
+
forward pass.
|
370 |
+
past_key_values: pre-computed values to pass to language model.
|
371 |
+
See past_key_values documentation in Hugging Face
|
372 |
+
CausalLM models.
|
373 |
+
use_cache: whether to use cached key values. See use_cache
|
374 |
+
documentation in Hugging Face CausalLM models.
|
375 |
+
"""
|
376 |
+
self.valid = True
|
377 |
+
self.lang_encoder.loc_list = None
|
378 |
+
if use_cached_vision_x:
|
379 |
+
# Case: use cached; vision_x should be cached and other
|
380 |
+
# vision-related inputs should not be provided.
|
381 |
+
assert (
|
382 |
+
vision_x is None
|
383 |
+
), "Expect vision_x to be None when use_cached_vision_x is True."
|
384 |
+
assert self.lang_encoder.is_conditioned()
|
385 |
+
else:
|
386 |
+
# Case: do not use caching (i.e. this is a standard forward pass);
|
387 |
+
self._encode_vision_x(
|
388 |
+
vision_x=vision_x,
|
389 |
+
image_nums=image_nums,
|
390 |
+
image_start_index_list=image_start_index_list,
|
391 |
+
added_bbox_list=added_bbox_list if add_box else None,
|
392 |
+
input_ids=lang_x,
|
393 |
+
relations=relations,
|
394 |
+
)
|
395 |
+
if self.apply_mask:
|
396 |
+
if self.roi_align:
|
397 |
+
attend_length = 1 + self.roi_output_size ** 2
|
398 |
+
else:
|
399 |
+
attend_length = 2
|
400 |
+
prebox_loc = (lang_x == self.prebox_token_id).nonzero()
|
401 |
+
loc_list = []
|
402 |
+
for (x, y) in prebox_loc:
|
403 |
+
x = x.item()
|
404 |
+
y = y.item()
|
405 |
+
for yy in range(y+1, lang_x.shape[1]):
|
406 |
+
if lang_x[x, yy] == self.endofobject_token_id:
|
407 |
+
# [batch_idx, [previsual:prebox], [object:endofobject-1]]
|
408 |
+
loc_list.append([x, [y-attend_length+1, y], [y+1, yy-1]])
|
409 |
+
self._condition_attention(loc_list=loc_list)
|
410 |
+
else:
|
411 |
+
self._condition_attention(None)
|
412 |
+
|
413 |
+
output = self.lang_encoder(
|
414 |
+
input_ids=lang_x,
|
415 |
+
attention_mask=attention_mask,
|
416 |
+
labels=labels,
|
417 |
+
past_key_values=past_key_values,
|
418 |
+
use_cache=use_cache,
|
419 |
+
output_hidden_states=True,
|
420 |
+
)
|
421 |
+
if vision_x is None:
|
422 |
+
output['loss'][0] += 0.0 * self.vis_proj(self.vision_encoder.visual(torch.randn(1, 3, 224, 224, device=lang_x.device, dtype=output['loss'].dtype))[1]).mean()
|
423 |
+
|
424 |
+
hidden_states = output["hidden_states"][-1]
|
425 |
+
if self.training and added_bbox_list is not None:
|
426 |
+
detection_losses, loss_dict = self.get_detection_losses(
|
427 |
+
input_ids=lang_x,
|
428 |
+
hidden_states=hidden_states,
|
429 |
+
added_bbox_list=added_bbox_list,
|
430 |
+
)
|
431 |
+
output["detection_losses"] = detection_losses
|
432 |
+
output["loss_dict"] = loss_dict
|
433 |
+
elif labels is None:
|
434 |
+
boxes, scores = self.get_detection_result(
|
435 |
+
input_ids=lang_x,
|
436 |
+
hidden_states=hidden_states,
|
437 |
+
debug_id=self.debug_id if hasattr(self, "debug_id") else None,
|
438 |
+
debug_mode=debug_mode,
|
439 |
+
)
|
440 |
+
output["boxes"] = boxes
|
441 |
+
output["scores"] = scores
|
442 |
+
|
443 |
+
if clear_conditioned_layers:
|
444 |
+
self.lang_encoder.clear_conditioned_layers()
|
445 |
+
self._condition_attention(None)
|
446 |
+
return output
|
447 |
+
|
448 |
+
def generate(
|
449 |
+
self,
|
450 |
+
vision_x: torch.Tensor,
|
451 |
+
lang_x: torch.Tensor,
|
452 |
+
attention_mask: torch.Tensor = None,
|
453 |
+
added_bbox_list=None,
|
454 |
+
num_beams=1,
|
455 |
+
max_new_tokens=None,
|
456 |
+
temperature=1.0,
|
457 |
+
top_k=0,
|
458 |
+
top_p=1.0,
|
459 |
+
no_repeat_ngram_size=0,
|
460 |
+
prefix_allowed_tokens_fn=None,
|
461 |
+
length_penalty=1.0,
|
462 |
+
num_return_sequences=1,
|
463 |
+
do_sample=False,
|
464 |
+
early_stopping=False,
|
465 |
+
bad_words_ids=None,
|
466 |
+
force_words_ids=None,
|
467 |
+
image_start_index_list=None,
|
468 |
+
image_nums=None,
|
469 |
+
min_length=None,
|
470 |
+
return_dict_in_generate=False,
|
471 |
+
output_hidden_states=False,
|
472 |
+
output_scores=False,
|
473 |
+
logits_processor_list=None,
|
474 |
+
eos_token_id=None,
|
475 |
+
):
|
476 |
+
"""
|
477 |
+
Generate text conditioned on vision and language inputs.
|
478 |
+
|
479 |
+
Args:
|
480 |
+
vision_x (torch.Tensor): Vision input
|
481 |
+
shape (B, T_img, F, C, H, W)
|
482 |
+
images in the same chunk are collated along T_img, and frames are collated along F
|
483 |
+
currently only F=1 is supported (single-frame videos)
|
484 |
+
lang_x (torch.Tensor): Language input
|
485 |
+
shape (B, T_txt)
|
486 |
+
max_length (int, optional): Maximum length of the output. Defaults to None.
|
487 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
488 |
+
num_beams (int, optional): Number of beams. Defaults to 1.
|
489 |
+
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
|
490 |
+
temperature (float, optional): Temperature. Defaults to 1.0.
|
491 |
+
top_k (int, optional): Top k. Defaults to 0.
|
492 |
+
top_p (float, optional): Top p. Defaults to 1.0.
|
493 |
+
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
|
494 |
+
length_penalty (float, optional): Length penalty. Defaults to 1.0.
|
495 |
+
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
|
496 |
+
do_sample (bool, optional): Do sample. Defaults to False.
|
497 |
+
early_stopping (bool, optional): Early stopping. Defaults to False.
|
498 |
+
Returns:
|
499 |
+
torch.Tensor: lang_x with generated tokens appended to it
|
500 |
+
"""
|
501 |
+
if num_beams > 1:
|
502 |
+
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
|
503 |
+
image_start_index_list = torch.tensor(image_start_index_list).repeat_interleave(num_beams, dim=0).tolist()
|
504 |
+
image_nums = torch.tensor(image_nums).repeat_interleave(num_beams, dim=0).tolist()
|
505 |
+
if added_bbox_list is not None and len(added_bbox_list) != 0:
|
506 |
+
added_bbox_list = added_bbox_list * num_beams
|
507 |
+
|
508 |
+
self._encode_vision_x(vision_x=vision_x, image_nums=image_nums, image_start_index_list=image_start_index_list, num_beams=num_beams, added_bbox_list=added_bbox_list, input_ids=lang_x.repeat_interleave(num_beams, dim=0))
|
509 |
+
|
510 |
+
if logits_processor_list is not None:
|
511 |
+
assert isinstance(logits_processor_list, list)
|
512 |
+
logits_processor_list = LogitsProcessorList(logits_processor_list)
|
513 |
+
output = self.lang_encoder.generate(
|
514 |
+
input_ids=lang_x,
|
515 |
+
attention_mask=attention_mask,
|
516 |
+
eos_token_id=(self.eoc_token_id) if eos_token_id is None else eos_token_id,
|
517 |
+
num_beams=num_beams,
|
518 |
+
max_new_tokens=max_new_tokens,
|
519 |
+
min_length=min_length,
|
520 |
+
length_penalty=length_penalty,
|
521 |
+
logits_processor=logits_processor_list,
|
522 |
+
return_dict_in_generate=return_dict_in_generate,
|
523 |
+
output_scores=output_scores,
|
524 |
+
)
|
525 |
+
self.lang_encoder.clear_conditioned_layers()
|
526 |
+
return output
|
527 |
+
|
528 |
+
def _get_data_list_and_visual_tokens(
|
529 |
+
self,
|
530 |
+
all_box_list,
|
531 |
+
box_token_id,
|
532 |
+
prebox_token_id,
|
533 |
+
input_ids,
|
534 |
+
vision_x,
|
535 |
+
nothing_embedding = None,
|
536 |
+
):
|
537 |
+
box_locations = (torch.logical_or(input_ids == box_token_id, input_ids == prebox_token_id)).nonzero()
|
538 |
+
prev_batch_idx = -1
|
539 |
+
media_idx = []
|
540 |
+
cnt = 0
|
541 |
+
data_list = []
|
542 |
+
visual_tokens = []
|
543 |
+
if len(all_box_list) != len(box_locations):
|
544 |
+
logging.info(f"WARNING. len(all_box_list) != len(box_locations) {len(all_box_list)} vs {len(box_locations)}")
|
545 |
+
self.valid = False
|
546 |
+
for III, (batch_idx, idx) in enumerate(box_locations):
|
547 |
+
batch_idx = batch_idx.item()
|
548 |
+
idx = idx.item()
|
549 |
+
if batch_idx != prev_batch_idx:
|
550 |
+
prev_batch_idx = batch_idx
|
551 |
+
this_input_ids = input_ids[batch_idx]
|
552 |
+
cnt += len(media_idx)
|
553 |
+
media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
|
554 |
+
for i in range(len(media_idx)):
|
555 |
+
if i == len(media_idx) - 1 or idx > media_idx[i] and idx < media_idx[i+1]:
|
556 |
+
break
|
557 |
+
image_index = cnt + i
|
558 |
+
size = int(vision_x[image_index].shape[0] ** 0.5)
|
559 |
+
image_feature = vision_x[image_index].reshape(size, size, -1)
|
560 |
+
try:
|
561 |
+
raw_xyxy = all_box_list[III]
|
562 |
+
except:
|
563 |
+
logging.info("out of scope for all_box_list")
|
564 |
+
raw_xyxy = all_box_list[-1]
|
565 |
+
region_xyxy = np.array(raw_xyxy) * size
|
566 |
+
x1, y1, x2, y2 = region_xyxy.astype(int).clip(0, size-1).tolist()
|
567 |
+
x2 = max(x1, x2)
|
568 |
+
y2 = max(y1, y2)
|
569 |
+
if x1 + y1 + x2 + y2 == 0.0 and nothing_embedding is not None:
|
570 |
+
visual_token = nothing_embedding
|
571 |
+
else:
|
572 |
+
if self.roi_align:
|
573 |
+
visual_token = torchvision.ops.roi_align(
|
574 |
+
image_feature.permute(2, 0, 1).unsqueeze(0),
|
575 |
+
[torch.tensor(region_xyxy.astype(np.float32)).unsqueeze(0).cuda()],
|
576 |
+
output_size=self.roi_output_size,
|
577 |
+
spatial_scale=1.0,
|
578 |
+
)
|
579 |
+
visual_token = visual_token.squeeze(0).flatten(1).permute(1, 0)
|
580 |
+
else:
|
581 |
+
visual_token = image_feature[y1:y2+1, x1:x2+1].reshape(-1, image_feature.shape[-1]).mean(0)
|
582 |
+
box = torch.tensor([0] + raw_xyxy, device=visual_token.device, dtype=visual_token.dtype)
|
583 |
+
data_list.append([visual_token, box, batch_idx, idx, i])
|
584 |
+
visual_tokens.append(visual_token)
|
585 |
+
return data_list, visual_tokens
|
586 |
+
|
587 |
+
def _encode_vision_x(self, vision_x: torch.Tensor, image_nums=None, image_start_index_list=None, added_bbox_list=None, num_beams=None, input_ids=None, relations=None):
|
588 |
+
"""
|
589 |
+
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
590 |
+
Args:
|
591 |
+
vision_x (torch.Tensor): Vision input
|
592 |
+
shape (B, T_img, F, C, H, W)
|
593 |
+
Images in the same chunk are collated along T_img, and frames are collated along F
|
594 |
+
Currently only F=1 is supported (single-frame videos)
|
595 |
+
|
596 |
+
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
597 |
+
"""
|
598 |
+
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
599 |
+
b, T, F = vision_x.shape[:3]
|
600 |
+
assert F == 1, "Only single frame supported"
|
601 |
+
|
602 |
+
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
603 |
+
if hasattr(self.vision_encoder, "visual"):
|
604 |
+
vision_x = self.vision_encoder.visual(vision_x)[1]
|
605 |
+
else:
|
606 |
+
vision_x = self.vision_encoder(vision_x).flatten(2).permute(0, 2, 1)
|
607 |
+
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
608 |
+
|
609 |
+
# print(vision_x[0,0,0])
|
610 |
+
# # DEBUG HERE
|
611 |
+
# if torch.distributed.get_rank() == 0:
|
612 |
+
# import pdb; pdb.set_trace()
|
613 |
+
# else:
|
614 |
+
# torch.distributed.barrier()
|
615 |
+
vision_x = vision_x.mean(2)
|
616 |
+
# vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
|
617 |
+
# vision_x = self.vis_proj(vision_x) + self.vis_position_embedding(self.vis_position_ids).unsqueeze(0)
|
618 |
+
vision_x = self.vis_proj(vision_x).squeeze(1)
|
619 |
+
self.image_embedding = vision_x
|
620 |
+
|
621 |
+
data_list = None
|
622 |
+
visual_tokens = None
|
623 |
+
if added_bbox_list is not None and input_ids is not None:
|
624 |
+
all_box_list = added_bbox_list[0].tolist()
|
625 |
+
for list in added_bbox_list[1:]:
|
626 |
+
all_box_list.extend(list.tolist())
|
627 |
+
data_list, visual_tokens = self._get_data_list_and_visual_tokens(
|
628 |
+
all_box_list=all_box_list,
|
629 |
+
box_token_id=self.box_token_id,
|
630 |
+
prebox_token_id=self.prebox_token_id,
|
631 |
+
input_ids=input_ids,
|
632 |
+
vision_x=vision_x,
|
633 |
+
nothing_embedding=self.lang_encoder.gpt_neox.embed_in(torch.tensor(self.nothing_token_id).to(self.lang_encoder.gpt_neox.embed_in.weight.device)) if self.nothing_token_id is not None else None,
|
634 |
+
)
|
635 |
+
|
636 |
+
first_layer = self.lang_encoder._get_decoder_layers()[0]
|
637 |
+
first_layer.condition_vis_x(vision_x, image_nums, image_start_index_list, num_beams=num_beams, visual_tokens=visual_tokens, data_list=[[d[2], d[3]] for d in data_list] if data_list is not None else data_list)
|
multimodal/build/lib/open_flamingo/src/flamingo_lm.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from .helpers import GatedCrossAttentionBlock
|
7 |
+
from .utils import getattr_recursive, setattr_recursive
|
8 |
+
|
9 |
+
|
10 |
+
class FlamingoLayer(nn.Module):
|
11 |
+
def __init__(self, decoder_layer):
|
12 |
+
super().__init__()
|
13 |
+
self.decoder_layer = decoder_layer
|
14 |
+
self.vis_x = None
|
15 |
+
self.image_nums = None
|
16 |
+
self.image_start_index_list = None
|
17 |
+
self.media_locations = None
|
18 |
+
self.add_visual_token = False
|
19 |
+
self.input_ids = None
|
20 |
+
|
21 |
+
def is_conditioned(self) -> bool:
|
22 |
+
"""Check whether the layer is conditioned."""
|
23 |
+
return self.vis_x is not None
|
24 |
+
|
25 |
+
# Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
|
26 |
+
def condition_vis_x(self, vis_x, image_nums=None, image_start_index_list=None, num_beams=None, visual_tokens=None, data_list=None):
|
27 |
+
self.vis_x = vis_x
|
28 |
+
self.image_nums = image_nums
|
29 |
+
self.image_start_index_list = image_start_index_list
|
30 |
+
self.num_beams = num_beams
|
31 |
+
self.visual_tokens = visual_tokens
|
32 |
+
self.data_list = data_list
|
33 |
+
self.input_ids = None
|
34 |
+
|
35 |
+
|
36 |
+
def condition_media_locations(self, media_locations):
|
37 |
+
self.media_locations = media_locations
|
38 |
+
|
39 |
+
def condition_attend_previous(self, attend_previous):
|
40 |
+
self.attend_previous = attend_previous
|
41 |
+
|
42 |
+
def forward(
|
43 |
+
self,
|
44 |
+
hidden_states, # alignment with hugging face name
|
45 |
+
attention_mask=None,
|
46 |
+
**decoder_layer_kwargs,
|
47 |
+
):
|
48 |
+
if self.media_locations is None:
|
49 |
+
raise ValueError("media_locations must be conditioned before forward pass")
|
50 |
+
|
51 |
+
if self.vis_x is not None:
|
52 |
+
if self.training:
|
53 |
+
single_length = self.vis_x.shape[-2]
|
54 |
+
image_nums = self.image_nums
|
55 |
+
image_start_index_list = self.image_start_index_list
|
56 |
+
image_nums = [0] + np.cumsum(image_nums).tolist()
|
57 |
+
for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
|
58 |
+
for index in start_indices:
|
59 |
+
if image_num_begin < image_num_end:
|
60 |
+
hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
|
61 |
+
image_num_begin += 1
|
62 |
+
|
63 |
+
if self.visual_tokens is not None and len(self.visual_tokens) != 0:
|
64 |
+
for i, (x, y) in enumerate(self.data_list):
|
65 |
+
if len(self.visual_tokens[i].shape) > 1:
|
66 |
+
# print(self.visual_tokens[i].shape[0], "embedding")
|
67 |
+
hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i]
|
68 |
+
else:
|
69 |
+
# print(self.visual_tokens[i].shape[0], "embedding")
|
70 |
+
hidden_states[x, y] = self.visual_tokens[i]
|
71 |
+
|
72 |
+
elif not self.training:
|
73 |
+
if (
|
74 |
+
("past_key_value" in decoder_layer_kwargs and decoder_layer_kwargs["past_key_value"] is None) or
|
75 |
+
("layer_past" in decoder_layer_kwargs and decoder_layer_kwargs["layer_past"] is None)
|
76 |
+
):
|
77 |
+
single_length = self.vis_x.shape[-2]
|
78 |
+
image_nums = self.image_nums
|
79 |
+
image_start_index_list = self.image_start_index_list
|
80 |
+
image_nums = [0] + np.cumsum(image_nums).tolist()
|
81 |
+
for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
|
82 |
+
for index in start_indices:
|
83 |
+
if image_num_begin < image_num_end:
|
84 |
+
hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
|
85 |
+
image_num_begin += 1
|
86 |
+
if self.visual_tokens is not None and len(self.visual_tokens) != 0:
|
87 |
+
for i, (x, y) in enumerate(self.data_list):
|
88 |
+
# import pdb; pdb.set_trace()
|
89 |
+
# print(x, y, self.visual_tokens[i].shape)
|
90 |
+
if len(self.visual_tokens[i].shape) > 1:
|
91 |
+
# print(self.visual_tokens[i].shape[0], "embedding")
|
92 |
+
hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i]
|
93 |
+
else:
|
94 |
+
# print(self.visual_tokens[i].shape[0], "embedding")
|
95 |
+
hidden_states[x, y] = self.visual_tokens[i]
|
96 |
+
hidden_states = self.decoder_layer(
|
97 |
+
hidden_states, attention_mask=attention_mask, **decoder_layer_kwargs
|
98 |
+
)
|
99 |
+
return hidden_states
|
100 |
+
|
101 |
+
|
102 |
+
class FlamingoLMMixin(nn.Module):
|
103 |
+
"""
|
104 |
+
Mixin to add cross-attention layers to a language model.
|
105 |
+
"""
|
106 |
+
|
107 |
+
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
|
108 |
+
self.decoder_layers_attr_name = decoder_layers_attr_name
|
109 |
+
|
110 |
+
def _get_decoder_layers(self):
|
111 |
+
return getattr_recursive(self, self.decoder_layers_attr_name)
|
112 |
+
|
113 |
+
def _set_decoder_layers(self, value):
|
114 |
+
setattr_recursive(self, self.decoder_layers_attr_name, value)
|
115 |
+
|
116 |
+
def init_flamingo(
|
117 |
+
self,
|
118 |
+
media_token_id,
|
119 |
+
use_media_placement_augmentation,
|
120 |
+
):
|
121 |
+
"""
|
122 |
+
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
|
123 |
+
"""
|
124 |
+
self._set_decoder_layers(
|
125 |
+
nn.ModuleList(
|
126 |
+
[FlamingoLayer(decoder_layer) for decoder_layer in self._get_decoder_layers()]
|
127 |
+
)
|
128 |
+
)
|
129 |
+
self.media_token_id = media_token_id
|
130 |
+
self.use_media_placement_augmentation = use_media_placement_augmentation
|
131 |
+
self.initialized_flamingo = True
|
132 |
+
|
133 |
+
def forward(self, *input, **kwargs):
|
134 |
+
"""Condition the Flamingo layers on the media locations before forward()"""
|
135 |
+
if not self.initialized_flamingo:
|
136 |
+
raise ValueError(
|
137 |
+
"Flamingo layers are not initialized. Please call `init_flamingo` first."
|
138 |
+
)
|
139 |
+
|
140 |
+
input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
|
141 |
+
media_locations = input_ids == self.media_token_id
|
142 |
+
attend_previous = (
|
143 |
+
(random.random() < 0.5) if self.use_media_placement_augmentation else True
|
144 |
+
)
|
145 |
+
|
146 |
+
if (
|
147 |
+
"gpt2" in self.__class__.__name__.lower()
|
148 |
+
or "codegen" in self.__class__.__name__.lower()
|
149 |
+
):
|
150 |
+
for layer in self.transformer.h:
|
151 |
+
layer.condition_media_locations(media_locations)
|
152 |
+
layer.condition_attend_previous(attend_previous)
|
153 |
+
elif "gptneox" in self.__class__.__name__.lower():
|
154 |
+
for layer in self.gpt_neox.layers:
|
155 |
+
layer.condition_media_locations(media_locations)
|
156 |
+
layer.condition_attend_previous(attend_previous)
|
157 |
+
else:
|
158 |
+
for layer in self.get_decoder().layers:
|
159 |
+
layer.condition_media_locations(media_locations)
|
160 |
+
layer.condition_attend_previous(attend_previous)
|
161 |
+
return super().forward(
|
162 |
+
*input, **kwargs
|
163 |
+
) # Call the other parent's forward method
|
164 |
+
|
165 |
+
def is_conditioned(self) -> bool:
|
166 |
+
"""Check whether all decoder layers are already conditioned."""
|
167 |
+
return all(l.is_conditioned() for l in self._get_decoder_layers())
|
168 |
+
|
169 |
+
def clear_conditioned_layers(self):
|
170 |
+
for layer in self._get_decoder_layers():
|
171 |
+
layer.condition_vis_x(None)
|
172 |
+
layer.condition_media_locations(None)
|
173 |
+
layer.condition_attend_previous(None)
|
multimodal/build/lib/open_flamingo/src/gcn.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn.parameter import Parameter
|
5 |
+
import math
|
6 |
+
from torch.autograd import Variable
|
7 |
+
from torchvision.ops import box_iou
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
class GraphConvolution(nn.Module):
|
12 |
+
"""
|
13 |
+
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, in_features, out_features, bias=True, skip=True):
|
17 |
+
super(GraphConvolution, self).__init__()
|
18 |
+
self.skip = skip
|
19 |
+
self.in_features = in_features
|
20 |
+
self.out_features = out_features
|
21 |
+
self.weight = Parameter(torch.Tensor(in_features, out_features))
|
22 |
+
if bias:
|
23 |
+
self.bias = Parameter(torch.Tensor(out_features))
|
24 |
+
else:
|
25 |
+
self.register_parameter('bias', None)
|
26 |
+
self.reset_parameters()
|
27 |
+
|
28 |
+
def reset_parameters(self):
|
29 |
+
stdv = 1. / math.sqrt(self.weight.size(1))
|
30 |
+
self.weight.data.uniform_(-stdv, stdv)
|
31 |
+
if self.bias is not None:
|
32 |
+
self.bias.data.uniform_(-stdv, stdv)
|
33 |
+
|
34 |
+
def forward(self, input, adj):
|
35 |
+
# TODO make fc more efficient via "pack_padded_sequence"
|
36 |
+
# import ipdb; ipdb.set_trace()
|
37 |
+
support = torch.bmm(input, self.weight.unsqueeze(
|
38 |
+
0).expand(input.shape[0], -1, -1))
|
39 |
+
output = torch.bmm(adj, support)
|
40 |
+
#output = SparseMM(adj)(support)
|
41 |
+
if self.bias is not None:
|
42 |
+
output += self.bias.unsqueeze(0).expand(input.shape[0], -1, -1)
|
43 |
+
if self.skip:
|
44 |
+
output += support
|
45 |
+
|
46 |
+
return output
|
47 |
+
|
48 |
+
def __repr__(self):
|
49 |
+
return self.__class__.__name__ + ' (' \
|
50 |
+
+ str(self.in_features) + ' -> ' \
|
51 |
+
+ str(self.out_features) + ')'
|
52 |
+
|
53 |
+
|
54 |
+
class GCN_sim(nn.Module):
|
55 |
+
def __init__(self, dim_in, dim_hidden, dim_out, dropout, num_layers):
|
56 |
+
super(GCN_sim, self).__init__()
|
57 |
+
assert num_layers >= 1
|
58 |
+
self.fc_k = nn.Linear(dim_in, dim_hidden)
|
59 |
+
self.fc_q = nn.Linear(dim_in, dim_hidden)
|
60 |
+
|
61 |
+
dim_hidden = dim_out if num_layers == 1 else dim_hidden
|
62 |
+
self.gcs = nn.ModuleList([
|
63 |
+
GraphConvolution(dim_in, dim_hidden)
|
64 |
+
])
|
65 |
+
|
66 |
+
for i in range(num_layers - 1):
|
67 |
+
dim_tmp = dim_out if i == num_layers-2 else dim_hidden
|
68 |
+
self.gcs.append(GraphConvolution(dim_hidden, dim_tmp))
|
69 |
+
|
70 |
+
self.dropout = dropout
|
71 |
+
|
72 |
+
def construct_graph(self, x, length):
|
73 |
+
# TODO make fc more efficient via "pack_padded_sequence"
|
74 |
+
emb_k = self.fc_k(x)
|
75 |
+
emb_q = self.fc_q(x)
|
76 |
+
|
77 |
+
s = torch.bmm(emb_k, emb_q.transpose(1, 2))
|
78 |
+
|
79 |
+
s_mask = s.data.new(*s.size()).fill_(1).bool() # [B, T1, T2]
|
80 |
+
# Init similarity mask using lengths
|
81 |
+
for i, (l_1, l_2) in enumerate(zip(length, length)):
|
82 |
+
s_mask[i][:l_1, :l_2] = 0
|
83 |
+
s_mask = Variable(s_mask)
|
84 |
+
s.data.masked_fill_(s_mask.data, -float("inf"))
|
85 |
+
|
86 |
+
a_weight = F.softmax(s, dim=2) # [B, t1, t2]
|
87 |
+
# remove nan from softmax on -inf
|
88 |
+
a_weight.data.masked_fill_(a_weight.data != a_weight.data, 0)
|
89 |
+
|
90 |
+
return a_weight
|
91 |
+
|
92 |
+
def forward(self, x, length):
|
93 |
+
adj_sim = self.construct_graph(x, length)
|
94 |
+
|
95 |
+
for gc in self.gcs:
|
96 |
+
x = F.relu(gc(x, adj_sim))
|
97 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
98 |
+
|
99 |
+
return x
|
100 |
+
|
101 |
+
|
102 |
+
class GCN(nn.Module):
|
103 |
+
def __init__(self, dim_in, dim_hidden, dim_out, dropout, mode, skip, num_layers, ST_n_next=None):
|
104 |
+
super(GCN, self).__init__()
|
105 |
+
assert len(mode) != 0
|
106 |
+
self.mode = mode
|
107 |
+
self.skip = skip
|
108 |
+
|
109 |
+
if "GCN_sim" in mode:
|
110 |
+
self.GCN_sim = GCN_sim(
|
111 |
+
dim_in, dim_hidden, dim_out, dropout, num_layers)
|
112 |
+
|
113 |
+
def forward(self, x, length):
|
114 |
+
|
115 |
+
out = []
|
116 |
+
if "GCN_sim" in self.mode:
|
117 |
+
out.append(self.GCN_sim(x, length))
|
118 |
+
|
119 |
+
out = sum(out)
|
120 |
+
if self.skip:
|
121 |
+
out += x
|
122 |
+
|
123 |
+
return out
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == '__main__':
|
127 |
+
model = GCN(512, 128, 512, 0.5, mode=[
|
128 |
+
"GCN_sim"], skip=True, num_layers=3, ST_n_next=3)
|
129 |
+
bs, T, N = 10, 5, 10
|
130 |
+
n_node = T*N
|
131 |
+
|
132 |
+
input = torch.rand(bs, n_node, 512)
|
133 |
+
length = torch.ones((bs))
|
134 |
+
length = length.type(torch.IntTensor)
|
135 |
+
bboxes = torch.rand((bs, 5, 10, 4))
|
136 |
+
|
137 |
+
output = model(input, length)
|
multimodal/build/lib/open_flamingo/src/helpers.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Taken from https://github.com/lucidrains/flamingo-pytorch
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from einops_exts import rearrange_many
|
8 |
+
from torch import einsum, nn
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def FeedForward(dim, mult=4):
|
16 |
+
inner_dim = int(dim * mult)
|
17 |
+
return nn.Sequential(
|
18 |
+
nn.LayerNorm(dim),
|
19 |
+
nn.Linear(dim, inner_dim, bias=False),
|
20 |
+
nn.GELU(),
|
21 |
+
nn.Linear(inner_dim, dim, bias=False),
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class PerceiverAttention(nn.Module):
|
26 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
27 |
+
super().__init__()
|
28 |
+
self.scale = dim_head**-0.5
|
29 |
+
self.heads = heads
|
30 |
+
inner_dim = dim_head * heads
|
31 |
+
|
32 |
+
self.norm_media = nn.LayerNorm(dim)
|
33 |
+
self.norm_latents = nn.LayerNorm(dim)
|
34 |
+
|
35 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
36 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
37 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
38 |
+
|
39 |
+
def forward(self, x, latents):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
x (torch.Tensor): image features
|
43 |
+
shape (b, T, n1, D)
|
44 |
+
latent (torch.Tensor): latent features
|
45 |
+
shape (b, T, n2, D)
|
46 |
+
"""
|
47 |
+
x = self.norm_media(x)
|
48 |
+
latents = self.norm_latents(latents)
|
49 |
+
|
50 |
+
h = self.heads
|
51 |
+
|
52 |
+
q = self.to_q(latents)
|
53 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
54 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
55 |
+
q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
|
56 |
+
q = q * self.scale
|
57 |
+
|
58 |
+
# attention
|
59 |
+
sim = einsum("... i d, ... j d -> ... i j", q, k)
|
60 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
61 |
+
attn = sim.softmax(dim=-1)
|
62 |
+
|
63 |
+
out = einsum("... i j, ... j d -> ... i d", attn, v)
|
64 |
+
out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
|
65 |
+
return self.to_out(out)
|
66 |
+
|
67 |
+
|
68 |
+
class PerceiverResampler(nn.Module):
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
*,
|
72 |
+
dim,
|
73 |
+
depth=6,
|
74 |
+
dim_head=64,
|
75 |
+
heads=8,
|
76 |
+
num_latents=64,
|
77 |
+
max_num_media=None,
|
78 |
+
max_num_frames=None,
|
79 |
+
ff_mult=4,
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
assert False, "Do not use PerceiverResampler"
|
83 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
84 |
+
self.frame_embs = (
|
85 |
+
nn.Parameter(torch.randn(max_num_frames, dim))
|
86 |
+
if exists(max_num_frames)
|
87 |
+
else None
|
88 |
+
)
|
89 |
+
self.media_time_embs = (
|
90 |
+
nn.Parameter(torch.randn(max_num_media, 1, dim))
|
91 |
+
if exists(max_num_media)
|
92 |
+
else None
|
93 |
+
)
|
94 |
+
|
95 |
+
self.layers = nn.ModuleList([])
|
96 |
+
for _ in range(depth):
|
97 |
+
self.layers.append(
|
98 |
+
nn.ModuleList(
|
99 |
+
[
|
100 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
101 |
+
FeedForward(dim=dim, mult=ff_mult),
|
102 |
+
]
|
103 |
+
)
|
104 |
+
)
|
105 |
+
|
106 |
+
self.norm = nn.LayerNorm(dim)
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
"""
|
110 |
+
Args:
|
111 |
+
x (torch.Tensor): image features
|
112 |
+
shape (b, T, F, v, D)
|
113 |
+
Returns:
|
114 |
+
shape (b, T, n, D) where n is self.num_latents
|
115 |
+
"""
|
116 |
+
b, T, F, v = x.shape[:4]
|
117 |
+
|
118 |
+
# frame and media time embeddings
|
119 |
+
if exists(self.frame_embs):
|
120 |
+
frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
|
121 |
+
x = x + frame_embs
|
122 |
+
x = rearrange(
|
123 |
+
x, "b T F v d -> b T (F v) d"
|
124 |
+
) # flatten the frame and spatial dimensions
|
125 |
+
if exists(self.media_time_embs):
|
126 |
+
x = x + self.media_time_embs[:T]
|
127 |
+
|
128 |
+
# blocks
|
129 |
+
latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
|
130 |
+
for attn, ff in self.layers:
|
131 |
+
latents = attn(x, latents) + latents
|
132 |
+
latents = ff(latents) + latents
|
133 |
+
return self.norm(latents)
|
134 |
+
|
135 |
+
|
136 |
+
# gated cross attention
|
137 |
+
|
138 |
+
|
139 |
+
class MaskedCrossAttention(nn.Module):
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
*,
|
143 |
+
dim,
|
144 |
+
dim_visual,
|
145 |
+
dim_head=64,
|
146 |
+
heads=8,
|
147 |
+
only_attend_immediate_media=True,
|
148 |
+
):
|
149 |
+
super().__init__()
|
150 |
+
self.scale = dim_head**-0.5
|
151 |
+
self.heads = heads
|
152 |
+
inner_dim = dim_head * heads
|
153 |
+
|
154 |
+
self.norm = nn.LayerNorm(dim)
|
155 |
+
|
156 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
157 |
+
self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
|
158 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
159 |
+
|
160 |
+
# whether for text to only attend to immediate preceding image, or all previous images
|
161 |
+
self.only_attend_immediate_media = only_attend_immediate_media
|
162 |
+
|
163 |
+
def forward(self, x, media, media_locations=None, attend_previous=True):
|
164 |
+
"""
|
165 |
+
Args:
|
166 |
+
x (torch.Tensor): text features
|
167 |
+
shape (B, T_txt, D_txt)
|
168 |
+
media (torch.Tensor): image features
|
169 |
+
shape (B, T_img, n, D_img) where n is the dim of the latents
|
170 |
+
media_locations: boolean mask identifying the media tokens in x
|
171 |
+
shape (B, T_txt)
|
172 |
+
attend_previous: bool
|
173 |
+
If false, ignores immediately preceding image and starts attending when following image
|
174 |
+
"""
|
175 |
+
assert attend_previous, "text must attend to the image that before it"
|
176 |
+
|
177 |
+
_, T_img, n = media.shape[:3]
|
178 |
+
h = self.heads
|
179 |
+
|
180 |
+
x = self.norm(x)
|
181 |
+
|
182 |
+
q = self.to_q(x)
|
183 |
+
media = rearrange(media, "b t n d -> b (t n) d")
|
184 |
+
|
185 |
+
k, v = self.to_kv(media).chunk(2, dim=-1)
|
186 |
+
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
|
187 |
+
|
188 |
+
q = q * self.scale
|
189 |
+
|
190 |
+
sim = einsum("... i d, ... j d -> ... i j", q, k)
|
191 |
+
|
192 |
+
if exists(media_locations):
|
193 |
+
# at each boolean of True, increment the time counter (relative to media time)
|
194 |
+
text_time = media_locations.cumsum(dim=-1)
|
195 |
+
media_time = torch.arange(T_img, device=x.device) + 1
|
196 |
+
|
197 |
+
if not attend_previous:
|
198 |
+
text_time[~media_locations] += 1
|
199 |
+
# make sure max is still the number of images in the sequence
|
200 |
+
text_time[
|
201 |
+
text_time
|
202 |
+
> repeat(
|
203 |
+
torch.count_nonzero(media_locations, dim=1),
|
204 |
+
"b -> b i",
|
205 |
+
i=text_time.shape[1],
|
206 |
+
)
|
207 |
+
] = 0
|
208 |
+
|
209 |
+
# text time must equal media time if only attending to most immediate image
|
210 |
+
# otherwise, as long as text time is greater than media time (if attending to all previous images / media)
|
211 |
+
mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
|
212 |
+
|
213 |
+
text_to_media_mask = mask_op(
|
214 |
+
rearrange(text_time, "b i -> b 1 i 1"),
|
215 |
+
repeat(media_time, "j -> 1 1 1 (j n)", n=n),
|
216 |
+
)
|
217 |
+
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
|
218 |
+
|
219 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
220 |
+
attn = sim.softmax(dim=-1)
|
221 |
+
|
222 |
+
if exists(media_locations) and self.only_attend_immediate_media:
|
223 |
+
# any text without a preceding media needs to have attention zeroed out
|
224 |
+
text_without_media_mask = text_time == 0
|
225 |
+
text_without_media_mask = rearrange(
|
226 |
+
text_without_media_mask, "b i -> b 1 i 1"
|
227 |
+
)
|
228 |
+
attn = attn.masked_fill(text_without_media_mask, 0.0)
|
229 |
+
|
230 |
+
out = einsum("... i j, ... j d -> ... i d", attn, v)
|
231 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
232 |
+
return self.to_out(out)
|
233 |
+
|
234 |
+
|
235 |
+
class GatedCrossAttentionBlock(nn.Module):
|
236 |
+
def __init__(
|
237 |
+
self,
|
238 |
+
*,
|
239 |
+
dim,
|
240 |
+
dim_visual,
|
241 |
+
dim_head=64,
|
242 |
+
heads=8,
|
243 |
+
ff_mult=4,
|
244 |
+
only_attend_immediate_media=True,
|
245 |
+
):
|
246 |
+
super().__init__()
|
247 |
+
self.attn = MaskedCrossAttention(
|
248 |
+
dim=dim,
|
249 |
+
dim_visual=dim_visual,
|
250 |
+
dim_head=dim_head,
|
251 |
+
heads=heads,
|
252 |
+
only_attend_immediate_media=only_attend_immediate_media,
|
253 |
+
)
|
254 |
+
|
255 |
+
def forward(
|
256 |
+
self,
|
257 |
+
x,
|
258 |
+
media,
|
259 |
+
media_locations=None,
|
260 |
+
attend_previous=True,
|
261 |
+
):
|
262 |
+
x = self.attn(x, media, media_locations=media_locations, attend_previous=attend_previous) + x
|
263 |
+
return x
|
multimodal/build/lib/open_flamingo/src/utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def extend_instance(obj, mixin):
|
2 |
+
"""Apply mixins to a class instance after creation"""
|
3 |
+
base_cls = obj.__class__
|
4 |
+
base_cls_name = obj.__class__.__name__
|
5 |
+
obj.__class__ = type(
|
6 |
+
base_cls_name, (mixin, base_cls), {}
|
7 |
+
) # mixin needs to go first for our forward() logic to work
|
8 |
+
|
9 |
+
|
10 |
+
def getattr_recursive(obj, att):
|
11 |
+
"""
|
12 |
+
Return nested attribute of obj
|
13 |
+
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
|
14 |
+
"""
|
15 |
+
if att == "":
|
16 |
+
return obj
|
17 |
+
i = att.find(".")
|
18 |
+
if i < 0:
|
19 |
+
return getattr(obj, att)
|
20 |
+
else:
|
21 |
+
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
22 |
+
|
23 |
+
|
24 |
+
def setattr_recursive(obj, att, val):
|
25 |
+
"""
|
26 |
+
Set nested attribute of obj
|
27 |
+
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
|
28 |
+
"""
|
29 |
+
if "." in att:
|
30 |
+
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
|
31 |
+
setattr(obj, att.split(".")[-1], val)
|
multimodal/build/lib/open_flamingo/train/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
multimodal/build/lib/open_flamingo/train/data2.py
ADDED
@@ -0,0 +1,868 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from multiprocessing import Value
|
8 |
+
import time
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import pickle as pkl
|
12 |
+
from open_flamingo.train.instruction_template import (
|
13 |
+
VG_RELATION_TEMPLATES,
|
14 |
+
PISC_TEMPLATES,
|
15 |
+
)
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import webdataset as wds
|
19 |
+
from PIL import Image
|
20 |
+
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
21 |
+
from torch.utils.data.distributed import DistributedSampler
|
22 |
+
from webdataset.tariterators import (
|
23 |
+
base_plus_ext,
|
24 |
+
tar_file_expander,
|
25 |
+
url_opener,
|
26 |
+
valid_sample,
|
27 |
+
)
|
28 |
+
|
29 |
+
from groundingdino.demo.caption_grounder import caption_grounder
|
30 |
+
from groundingdino.demo.inference_on_laion import add_loc_to_text
|
31 |
+
from groundingdino.demo.inference_on_laion import nms_without_score
|
32 |
+
from groundingdino.demo.inference_on_laion import calculate_iou
|
33 |
+
|
34 |
+
Image.MAX_IMAGE_PIXELS = 1000000000
|
35 |
+
LAION2B_NUM_SAMPLE = 1500000000
|
36 |
+
VQAV2_TRAIN_NUM_SAMPLE = 1828467
|
37 |
+
VG_RELATION_BBOX_SIZE = 600
|
38 |
+
|
39 |
+
REL_LABELS = ['__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind', 'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for', 'from', 'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on', 'looking at', 'lying on', 'made of', 'mounted on', 'near', 'of', 'on', 'on back of', 'over', 'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on', 'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']
|
40 |
+
|
41 |
+
try:
|
42 |
+
import horovod.torch as hvd
|
43 |
+
except ImportError:
|
44 |
+
hvd = None
|
45 |
+
|
46 |
+
class ConcatDataset(IterableDataset):
|
47 |
+
def __init__(
|
48 |
+
self, dataset, max_length,
|
49 |
+
delimiter_id, pad_id=None, media_id=None, endofmedia_id=None,
|
50 |
+
image_embedding_size=-2, single=False, box_id=None, visual_id=None,
|
51 |
+
):
|
52 |
+
self.dataset = dataset
|
53 |
+
self.max_length = max_length
|
54 |
+
self.delimiter_id = torch.ones(1,1).long() * delimiter_id
|
55 |
+
if pad_id is not None:
|
56 |
+
self.pad_id = int(pad_id)
|
57 |
+
if media_id is not None:
|
58 |
+
self.media_id = torch.ones(1,1).long() * int(media_id)
|
59 |
+
if endofmedia_id is not None:
|
60 |
+
self.endofmedia_id = torch.ones(1,1).long() * int(endofmedia_id)
|
61 |
+
if image_embedding_size > 0:
|
62 |
+
logging.info(f"image_embedding_size: {image_embedding_size}")
|
63 |
+
self.image_embedding_size = image_embedding_size + 2
|
64 |
+
self.single = single
|
65 |
+
self.box_id = box_id
|
66 |
+
self.visual_id = visual_id
|
67 |
+
|
68 |
+
def __iter__(self):
|
69 |
+
while True:
|
70 |
+
input_ids_list = []
|
71 |
+
attention_mask_list = []
|
72 |
+
image_list = []
|
73 |
+
image_start_index_list = []
|
74 |
+
added_bbox_list = []
|
75 |
+
relations_list = []
|
76 |
+
cnt = 0
|
77 |
+
while cnt < self.max_length:
|
78 |
+
sample = next(self.dataset)
|
79 |
+
if len(sample) >= 4:
|
80 |
+
image = sample[0].unsqueeze(0)
|
81 |
+
input_ids = sample[1]
|
82 |
+
attention_mask = sample[2]
|
83 |
+
added_bbox = sample[3]
|
84 |
+
image_list.append(image)
|
85 |
+
added_bbox_list.append(added_bbox)
|
86 |
+
if len(sample) == 5:
|
87 |
+
relations_list.append(sample[4])
|
88 |
+
else:
|
89 |
+
sample = sample[0]
|
90 |
+
input_ids = sample[0]
|
91 |
+
attention_mask = sample[1]
|
92 |
+
input_ids_list.append(input_ids)
|
93 |
+
attention_mask_list.append(attention_mask)
|
94 |
+
cnt += input_ids.shape[-1]
|
95 |
+
if self.single:
|
96 |
+
break
|
97 |
+
input_ids = torch.cat(input_ids_list, dim=-1)[0]
|
98 |
+
attention_mask = torch.cat(attention_mask_list, dim=-1)[0]
|
99 |
+
if not self.single:
|
100 |
+
input_ids = input_ids[:self.max_length]
|
101 |
+
attention_mask = attention_mask[:self.max_length]
|
102 |
+
# TODO: fix visual number not match
|
103 |
+
if len(image_list) != 0:
|
104 |
+
images = torch.cat(image_list, dim=0)
|
105 |
+
image_begin = (input_ids == self.media_id[0,0]).nonzero().view(-1)
|
106 |
+
image_end = (input_ids == self.endofmedia_id[0,0]).nonzero().view(-1)
|
107 |
+
if len(image_begin) != len(image_end):
|
108 |
+
assert len(image_begin) == len(image_end) + 1
|
109 |
+
input_ids[image_begin[-1]:] = self.pad_id
|
110 |
+
attention_mask[image_begin[-1]:] = 0
|
111 |
+
image_begin = image_begin[:-1]
|
112 |
+
eos_token_num = len((input_ids == self.delimiter_id[0,0]).nonzero().view(-1))
|
113 |
+
if eos_token_num != len(image_begin) + 1:
|
114 |
+
input_ids[image_begin[-1]:] = self.pad_id
|
115 |
+
attention_mask[image_begin[-1]:] = 0
|
116 |
+
image_begin = image_begin[:-1]
|
117 |
+
image_end = image_end[:-1]
|
118 |
+
images = images[:len(image_end)]
|
119 |
+
added_bbox_list = added_bbox_list[:len(image_end)]
|
120 |
+
relations_list = relations_list[:len(image_end)]
|
121 |
+
image_start_index_list = (image_begin + 1).tolist()
|
122 |
+
expand_list = added_bbox_list[0]
|
123 |
+
for x in added_bbox_list[1:]:
|
124 |
+
expand_list.extend(x)
|
125 |
+
yield images, len(images), image_start_index_list, input_ids, attention_mask, expand_list, relations_list
|
126 |
+
else:
|
127 |
+
yield input_ids, attention_mask
|
128 |
+
|
129 |
+
|
130 |
+
class SharedEpoch:
|
131 |
+
def __init__(self, epoch: int = 0):
|
132 |
+
self.shared_epoch = Value("i", epoch)
|
133 |
+
|
134 |
+
def set_value(self, epoch):
|
135 |
+
self.shared_epoch.value = epoch
|
136 |
+
|
137 |
+
def get_value(self):
|
138 |
+
return self.shared_epoch.value
|
139 |
+
|
140 |
+
|
141 |
+
@dataclass
|
142 |
+
class DataInfo:
|
143 |
+
dataloader: DataLoader
|
144 |
+
sampler: DistributedSampler = None
|
145 |
+
shared_epoch: SharedEpoch = None
|
146 |
+
|
147 |
+
def set_epoch(self, epoch):
|
148 |
+
if self.shared_epoch is not None:
|
149 |
+
self.shared_epoch.set_value(epoch)
|
150 |
+
if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
|
151 |
+
self.sampler.set_epoch(epoch)
|
152 |
+
|
153 |
+
|
154 |
+
def filter_no_caption_or_no_image(sample):
|
155 |
+
return ("txt" in sample) and (
|
156 |
+
"png" in sample or "jpg" in sample or "jpeg" in sample
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
def log_and_continue(exn):
|
161 |
+
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
|
162 |
+
if "ValueError" in repr(exn) or "KeyError" in repr(exn): # Avoid spamming logs with these
|
163 |
+
return True
|
164 |
+
logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
|
165 |
+
return True
|
166 |
+
# DEBUG
|
167 |
+
# log_and_continue = None
|
168 |
+
# DEBUG
|
169 |
+
|
170 |
+
|
171 |
+
def group_by_keys_nothrow(
|
172 |
+
data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
|
173 |
+
):
|
174 |
+
"""Return function over iterator that groups key, value pairs into samples.
|
175 |
+
|
176 |
+
:param keys: function that splits the key into key and extension (base_plus_ext)
|
177 |
+
:param lcase: convert suffixes to lower case (Default value = True)
|
178 |
+
"""
|
179 |
+
current_sample = None
|
180 |
+
tar_idx = None
|
181 |
+
for filesample in data:
|
182 |
+
assert isinstance(filesample, dict)
|
183 |
+
current_tar_idx = filesample["__url__"].split("/")[-1].split(".")[0]
|
184 |
+
if current_tar_idx != tar_idx:
|
185 |
+
tar_idx = current_tar_idx
|
186 |
+
if "blip2_all_data_ground" in filesample["__url__"]:
|
187 |
+
relation_data_dir = os.path.join("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_all_data_relation", tar_idx)
|
188 |
+
missing_file = False
|
189 |
+
try:
|
190 |
+
data_info = pkl.load(open(os.path.join(relation_data_dir, "custom_data_info.pkl"), "rb"))
|
191 |
+
prediction = pkl.load(open(os.path.join(relation_data_dir, "custom_prediction.pkl"), "rb"))
|
192 |
+
idx_to_files = data_info["idx_to_files"]
|
193 |
+
ind_to_classes = data_info["ind_to_classes"]
|
194 |
+
ind_to_predicates = data_info["ind_to_predicates"]
|
195 |
+
files_to_idx = {x.split("#")[-1]: i for i, x in enumerate(idx_to_files)}
|
196 |
+
except:
|
197 |
+
missing_file = True
|
198 |
+
fname, value = filesample["fname"], filesample["data"]
|
199 |
+
prefix, suffix = keys(fname)
|
200 |
+
if prefix is None:
|
201 |
+
continue
|
202 |
+
if lcase:
|
203 |
+
suffix = suffix.lower()
|
204 |
+
# FIXME webdataset version throws if suffix in current_sample, but we have a potential for
|
205 |
+
# this happening in the current LAION400m dataset if a tar ends with same prefix as the next
|
206 |
+
# begins, rare, but can happen since prefix aren't unique across tar files in that dataset
|
207 |
+
if (
|
208 |
+
current_sample is None
|
209 |
+
or prefix != current_sample["__key__"]
|
210 |
+
or suffix in current_sample
|
211 |
+
):
|
212 |
+
if valid_sample(current_sample):
|
213 |
+
yield current_sample
|
214 |
+
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
|
215 |
+
if "blip2_all_data_ground" in filesample["__url__"] and not missing_file:
|
216 |
+
try:
|
217 |
+
idx = files_to_idx[prefix]
|
218 |
+
prediction[idx]["bbox"] = [np.array(bbox)/VG_RELATION_BBOX_SIZE for bbox in prediction[idx]["bbox"]]
|
219 |
+
current_sample["relation_data"] = prediction[idx]
|
220 |
+
except:
|
221 |
+
current_sample["relation_data"] = dict()
|
222 |
+
else:
|
223 |
+
current_sample["relation_data"] = dict()
|
224 |
+
if suffixes is None or suffix in suffixes:
|
225 |
+
current_sample[suffix] = value
|
226 |
+
if valid_sample(current_sample):
|
227 |
+
yield current_sample
|
228 |
+
|
229 |
+
|
230 |
+
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
|
231 |
+
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
|
232 |
+
streams = url_opener(src, handler=handler)
|
233 |
+
files = tar_file_expander(streams, handler=handler)
|
234 |
+
samples = group_by_keys_nothrow(files, handler=handler)
|
235 |
+
return samples
|
236 |
+
|
237 |
+
|
238 |
+
def pytorch_worker_seed(increment=0):
|
239 |
+
"""get dataloader worker seed from pytorch"""
|
240 |
+
worker_info = get_worker_info()
|
241 |
+
if worker_info is not None:
|
242 |
+
# favour using the seed already created for pytorch dataloader workers if it exists
|
243 |
+
seed = worker_info.seed
|
244 |
+
if increment:
|
245 |
+
# space out seed increments so they can't overlap across workers in different iterations
|
246 |
+
seed += increment * max(1, worker_info.num_workers)
|
247 |
+
return seed
|
248 |
+
# fallback to wds rank based seed
|
249 |
+
return wds.utils.pytorch_worker_seed()
|
250 |
+
|
251 |
+
|
252 |
+
_SHARD_SHUFFLE_SIZE = 2000
|
253 |
+
_SHARD_SHUFFLE_INITIAL = 500
|
254 |
+
_SAMPLE_SHUFFLE_SIZE = 5000
|
255 |
+
_SAMPLE_SHUFFLE_INITIAL = 1000
|
256 |
+
|
257 |
+
|
258 |
+
class ResampledShards2(IterableDataset):
|
259 |
+
"""An iterable dataset yielding a list of urls."""
|
260 |
+
|
261 |
+
def __init__(
|
262 |
+
self,
|
263 |
+
urls,
|
264 |
+
nshards=sys.maxsize,
|
265 |
+
worker_seed=None,
|
266 |
+
deterministic=False,
|
267 |
+
epoch=-1,
|
268 |
+
):
|
269 |
+
"""Sample shards from the shard list with replacement.
|
270 |
+
:param urls: a list of URLs as a Python list or brace notation string
|
271 |
+
"""
|
272 |
+
super().__init__()
|
273 |
+
urls = wds.shardlists.expand_urls(urls)
|
274 |
+
self.urls = urls
|
275 |
+
assert isinstance(self.urls[0], str)
|
276 |
+
self.nshards = nshards
|
277 |
+
self.rng = random.Random()
|
278 |
+
self.worker_seed = worker_seed
|
279 |
+
self.deterministic = deterministic
|
280 |
+
self.epoch = epoch
|
281 |
+
|
282 |
+
def __iter__(self):
|
283 |
+
"""Return an iterator over the shards."""
|
284 |
+
if isinstance(self.epoch, SharedEpoch):
|
285 |
+
epoch = self.epoch.get_value()
|
286 |
+
else:
|
287 |
+
# NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
|
288 |
+
# situation as different workers may wrap at different times (or not at all).
|
289 |
+
self.epoch += 1
|
290 |
+
epoch = self.epoch
|
291 |
+
|
292 |
+
if self.deterministic:
|
293 |
+
# reset seed w/ epoch if deterministic
|
294 |
+
if self.worker_seed is None:
|
295 |
+
# pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
|
296 |
+
seed = pytorch_worker_seed(epoch)
|
297 |
+
else:
|
298 |
+
seed = self.worker_seed() + epoch
|
299 |
+
seed = seed + int(time.time())
|
300 |
+
self.rng.seed(seed)
|
301 |
+
# logging.info(f"epoch: {epoch} seed: {seed}")
|
302 |
+
self.rng.shuffle(self.urls)
|
303 |
+
# logging.info(f"{len(self.urls)} | {self.urls[:2]}")
|
304 |
+
for url in self.urls:
|
305 |
+
# logging.info(f"{seed}: {url}")
|
306 |
+
yield dict(url=url)
|
307 |
+
|
308 |
+
|
309 |
+
def preprocess_image(sample, image_processor):
|
310 |
+
image = image_processor(sample)
|
311 |
+
return image
|
312 |
+
|
313 |
+
|
314 |
+
def preprocess_text(sample, tokenizer, max_length, single=False):
|
315 |
+
if not single:
|
316 |
+
text = tokenizer(tokenizer.bos_token+sample.strip(), return_tensors="pt", max_length=max_length, truncation=True)
|
317 |
+
else:
|
318 |
+
text = tokenizer(tokenizer.bos_token+sample.strip(), return_tensors="pt", max_length=max_length, truncation=True, padding='max_length')
|
319 |
+
return text["input_ids"], text["attention_mask"]
|
320 |
+
|
321 |
+
|
322 |
+
def preprocess_encoded_text(sample, tokenizer, max_length):
|
323 |
+
sample = sample.decode("utf-8")
|
324 |
+
return preprocess_text(sample, tokenizer, max_length=max_length)
|
325 |
+
|
326 |
+
|
327 |
+
def _merge_bbox_previsual(added_bbox_list):
|
328 |
+
bbox_list = []
|
329 |
+
for bboxes in added_bbox_list:
|
330 |
+
x1 = bboxes[:, 0].min()
|
331 |
+
y1 = bboxes[:, 1].min()
|
332 |
+
x2 = bboxes[:, 2].max()
|
333 |
+
y2 = bboxes[:, 3].max()
|
334 |
+
bbox_list.append(torch.tensor([x1, y1, x2, y2], device=bboxes.device, dtype=bboxes.dtype).unsqueeze(0))
|
335 |
+
return bbox_list
|
336 |
+
|
337 |
+
|
338 |
+
def _find_idx(text, subtext):
|
339 |
+
loc = 0
|
340 |
+
locs = []
|
341 |
+
while text.find(subtext, loc) != -1:
|
342 |
+
loc = text.find(subtext, loc)
|
343 |
+
locs.append(loc)
|
344 |
+
loc += len(subtext)
|
345 |
+
return locs
|
346 |
+
|
347 |
+
def preprocess_ground_caption(sample, image_processor, tokenizer, image_embedding_size, generator, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None, args=None):
|
348 |
+
assert max_length is not None
|
349 |
+
assert not single, "single is not supported for preprocess_ground_caption"
|
350 |
+
image, caption, logits_filt, boxes_filt, relation_data = sample
|
351 |
+
if len(logits_filt.shape) == 1 and logits_filt.shape[0] == 4 and len(boxes_filt.shape) == 1 and boxes_filt.shape[0] == 4:
|
352 |
+
raise NotImplementedError # lack relation data
|
353 |
+
return preprocess_visual_genome(sample=sample, image_processor=image_processor, tokenizer=tokenizer, image_embedding_size=image_embedding_size, prob_ground=prob_ground, single=single, use_format_v2=use_format_v2, add_visual_token=add_visual_token, max_length=max_length)
|
354 |
+
image = preprocess_image(image, image_processor=image_processor)
|
355 |
+
added_bbox = []
|
356 |
+
if (prob_ground != 0 and random.random() <= prob_ground) or prob_ground == 1.0:
|
357 |
+
boxes_filt, pred_phrases = generator.postprocess(logits_filt, boxes_filt, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
|
358 |
+
caption, added_bbox = add_loc_to_text(
|
359 |
+
boxes_filt, pred_phrases, caption,
|
360 |
+
expand=args.expand, always_expand=args.longer_previsual,
|
361 |
+
)
|
362 |
+
visual_loc = []
|
363 |
+
obj_loc = []
|
364 |
+
endofobj_loc = []
|
365 |
+
visual_token = "<|#visual#|>"
|
366 |
+
previsual_token = "<|#previsual#|>"
|
367 |
+
box_token = "<|#box#|>"
|
368 |
+
prebox_token = "<|#prebox#|>"
|
369 |
+
end_token = "<|#endofobject#|>"
|
370 |
+
object_token = "<|#object#|>"
|
371 |
+
end_of_attr_token = "<|#endofattr#|>"
|
372 |
+
preend_of_attr_token = "<|#preendofattr#|>"
|
373 |
+
visual_loc = _find_idx(caption, visual_token)
|
374 |
+
try:
|
375 |
+
if len(visual_loc) != len(added_bbox):
|
376 |
+
logging.warning(f"visual_loc: {visual_loc}")
|
377 |
+
logging.warning(f"added_bbox: {added_bbox}")
|
378 |
+
except:
|
379 |
+
pass
|
380 |
+
assert len(visual_loc) == len(added_bbox)
|
381 |
+
delta = 0
|
382 |
+
for i, (loc, boxes) in enumerate(zip(visual_loc, added_bbox)):
|
383 |
+
loc += delta
|
384 |
+
boxes = nms_without_score(boxes)
|
385 |
+
added_bbox[i] = boxes
|
386 |
+
added_tokens = end_token + visual_token + box_token * len(boxes) + end_of_attr_token
|
387 |
+
caption = caption[:loc] + added_tokens + caption[len(visual_token) + loc:]
|
388 |
+
delta += len(added_tokens) - len(visual_token)
|
389 |
+
|
390 |
+
if use_format_v2:
|
391 |
+
merge_added_bbox = _merge_bbox_previsual(added_bbox)
|
392 |
+
# step 1: move <|#object#|> before the space char
|
393 |
+
while caption.find(f" {object_token}") != -1:
|
394 |
+
caption = caption.replace(f" {object_token}", f"{object_token} ")
|
395 |
+
# step 2: add <|#previsual#|> after <|#object#|> for 75% except the first object
|
396 |
+
i = 0
|
397 |
+
II = -1
|
398 |
+
if args.no_visual:
|
399 |
+
flag = False
|
400 |
+
delete_visual_prob = 10.0
|
401 |
+
else:
|
402 |
+
flag = True
|
403 |
+
delete_visual_prob = 0.75
|
404 |
+
while i < len(caption):
|
405 |
+
if caption[i: i + len(object_token)] == object_token:
|
406 |
+
II += 1
|
407 |
+
if (not args.longer_previsual and not flag and random.random() < delete_visual_prob) or (args.longer_previsual and (flag or random.random() < delete_visual_prob)):
|
408 |
+
# delete visual and add previsual
|
409 |
+
visual_start_idx = caption.find(end_token, i+1) + len(end_token)
|
410 |
+
visual_end_idx = caption.find(end_of_attr_token, visual_start_idx+1) + len(end_of_attr_token)
|
411 |
+
caption = caption[:visual_start_idx] + caption[visual_end_idx:]
|
412 |
+
caption = caption[:i + len(object_token)] + previsual_token + prebox_token + preend_of_attr_token + caption[i + len(object_token):]
|
413 |
+
added_bbox[II] = merge_added_bbox[II]
|
414 |
+
i += 1
|
415 |
+
flag = False
|
416 |
+
if args.no_previsual and args.no_visual:
|
417 |
+
caption = caption.replace(previsual_token, "").replace(prebox_token, "").replace(preend_of_attr_token, "")
|
418 |
+
added_bbox = []
|
419 |
+
caption = caption.replace(preend_of_attr_token, object_token).replace(end_of_attr_token, end_token)
|
420 |
+
|
421 |
+
|
422 |
+
if args.roi_align:
|
423 |
+
i = 0
|
424 |
+
pad_num = args.roi_output_size ** 2 - 1
|
425 |
+
while i < len(caption):
|
426 |
+
if caption[i: i + len(prebox_token)] == prebox_token:
|
427 |
+
caption = caption[:i] + tokenizer.pad_token * pad_num + caption[i:]
|
428 |
+
i += len(tokenizer.pad_token) * pad_num + len(prebox_token)
|
429 |
+
elif caption[i: i + len(box_token)] == box_token:
|
430 |
+
caption = caption[:i] + tokenizer.pad_token * pad_num + caption[i:]
|
431 |
+
i += len(tokenizer.pad_token) * pad_num + len(box_token)
|
432 |
+
i += 1
|
433 |
+
|
434 |
+
caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
|
435 |
+
input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
|
436 |
+
relations = []
|
437 |
+
if args.only_grounded_sample and "<|#visual#|>" not in caption:
|
438 |
+
raise ValueError
|
439 |
+
return image, input_ids, attention_mask, added_bbox, relations
|
440 |
+
|
441 |
+
|
442 |
+
def preprocess_visual_genome(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
|
443 |
+
assert max_length is not None
|
444 |
+
assert not single, "single is not supported for preprocess_ground_caption"
|
445 |
+
image, caption, xyxy, _ = sample
|
446 |
+
image = preprocess_image(image, image_processor=image_processor)
|
447 |
+
caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|><|#object#|>" + caption.strip() + "<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
|
448 |
+
input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
|
449 |
+
added_bbox = [torch.tensor(np.expand_dims(xyxy, 0).astype(np.float32) / 224)]
|
450 |
+
return image, input_ids, attention_mask, added_bbox
|
451 |
+
|
452 |
+
special_predicate = [
|
453 |
+
"and",
|
454 |
+
"has",
|
455 |
+
"says",
|
456 |
+
"wears",
|
457 |
+
]
|
458 |
+
|
459 |
+
original_predicate = {
|
460 |
+
"and": "and",
|
461 |
+
"has": "have",
|
462 |
+
"says": "say",
|
463 |
+
"wears": "wear",
|
464 |
+
}
|
465 |
+
|
466 |
+
|
467 |
+
def generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation):
|
468 |
+
if relation in ["and", "of"]:
|
469 |
+
id = 0
|
470 |
+
else:
|
471 |
+
id = random.choice(range(len(VG_RELATION_TEMPLATES)))
|
472 |
+
text = VG_RELATION_TEMPLATES[id].format(nameA=nameA, nameB=nameB, relation=relation, use_is="is" if relation not in special_predicate else "", is_or_does="is" if relation not in special_predicate else "does", relation_do=relation if relation not in special_predicate else original_predicate[relation])
|
473 |
+
if id in [0]:
|
474 |
+
added_bbox = [
|
475 |
+
torch.tensor([boxA]),
|
476 |
+
torch.tensor([boxB]),
|
477 |
+
]
|
478 |
+
elif id in [1]:
|
479 |
+
added_bbox = [
|
480 |
+
torch.tensor([boxA]),
|
481 |
+
torch.tensor([boxB]),
|
482 |
+
torch.tensor([boxA]),
|
483 |
+
torch.tensor([boxB]),
|
484 |
+
]
|
485 |
+
elif id in [2]:
|
486 |
+
added_bbox = [
|
487 |
+
torch.tensor([boxA]),
|
488 |
+
torch.tensor([boxA]),
|
489 |
+
torch.tensor([boxB]),
|
490 |
+
]
|
491 |
+
elif id in [3]:
|
492 |
+
added_bbox = [
|
493 |
+
torch.tensor([boxB]),
|
494 |
+
torch.tensor([boxA]),
|
495 |
+
torch.tensor([boxB]),
|
496 |
+
]
|
497 |
+
elif id in [4]:
|
498 |
+
added_bbox = [
|
499 |
+
torch.tensor([boxA]),
|
500 |
+
torch.tensor([boxB]),
|
501 |
+
]
|
502 |
+
elif id in [5]:
|
503 |
+
added_bbox = [
|
504 |
+
torch.tensor([boxB]),
|
505 |
+
torch.tensor([boxA]),
|
506 |
+
]
|
507 |
+
else:
|
508 |
+
raise NotImplementedError
|
509 |
+
return text, added_bbox
|
510 |
+
|
511 |
+
def generate_pisc_sample(boxA, boxB, relation):
|
512 |
+
id = random.choice(range(len(PISC_TEMPLATES)))
|
513 |
+
text = PISC_TEMPLATES[id].format(relation=relation)
|
514 |
+
if id in [0]:
|
515 |
+
if random.random() < 0.5:
|
516 |
+
added_bbox = [
|
517 |
+
torch.tensor([boxA]),
|
518 |
+
torch.tensor([boxB]),
|
519 |
+
]
|
520 |
+
else:
|
521 |
+
added_bbox = [
|
522 |
+
torch.tensor([boxB]),
|
523 |
+
torch.tensor([boxA]),
|
524 |
+
]
|
525 |
+
elif id in [1]:
|
526 |
+
if random.random() < 0.5:
|
527 |
+
added_bbox = [torch.tensor([boxA, boxB])]
|
528 |
+
else:
|
529 |
+
added_bbox = [torch.tensor([boxB, boxA])]
|
530 |
+
return text, added_bbox
|
531 |
+
|
532 |
+
|
533 |
+
def preprocess_instruct(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
|
534 |
+
image_path, dataset, data = sample
|
535 |
+
image = Image.open(image_path)
|
536 |
+
size = image_processor.transforms[0].size
|
537 |
+
image = image.resize((size, size))
|
538 |
+
if dataset == "pisc_relation_split":
|
539 |
+
boxA = data[0]
|
540 |
+
boxB = data[1]
|
541 |
+
relation = data[2]
|
542 |
+
text, added_bbox = generate_pisc_sample(boxA, boxB, relation)
|
543 |
+
# import cv2
|
544 |
+
# boxA *= size
|
545 |
+
# boxB *= size
|
546 |
+
# open_cv_image = np.array(image)
|
547 |
+
# open_cv_image = open_cv_image[:, :, ::-1].copy()
|
548 |
+
# open_cv_image = cv2.rectangle(open_cv_image, boxA[:2].astype(int), boxA[2:].astype(int), (255, 0, 0), 2)
|
549 |
+
# open_cv_image = cv2.rectangle(open_cv_image, boxB[:2].astype(int), boxB[2:].astype(int), (0, 255, 0), 2)
|
550 |
+
# cv2.imwrite("output.jpg", open_cv_image)
|
551 |
+
# import pdb; pdb.set_trace()
|
552 |
+
elif dataset == "vg_relation":
|
553 |
+
boxA = data[0][0]
|
554 |
+
nameA = data[0][1]
|
555 |
+
boxB = data[1][0]
|
556 |
+
nameB = data[1][1]
|
557 |
+
relation = data[2]
|
558 |
+
text, added_bbox = generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation)
|
559 |
+
image = preprocess_image(image, image_processor=image_processor)
|
560 |
+
caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + text + tokenizer.eos_token
|
561 |
+
input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=True)
|
562 |
+
# return image, input_ids, attention_mask, added_bbox
|
563 |
+
images = image.unsqueeze(0)
|
564 |
+
image_start_index_list = [2]
|
565 |
+
return images, len(images), image_start_index_list, input_ids, attention_mask, added_bbox
|
566 |
+
|
567 |
+
|
568 |
+
def preprocess_caption(sample, image_processor, tokenizer, image_embedding_size, max_length, single=False):
|
569 |
+
image, caption = sample
|
570 |
+
caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
|
571 |
+
image = preprocess_image(image, image_processor=image_processor)
|
572 |
+
input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=single)
|
573 |
+
return image, input_ids, attention_mask
|
574 |
+
|
575 |
+
|
576 |
+
def get_pile_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
577 |
+
input_shards = args.pile_shards
|
578 |
+
assert input_shards is not None
|
579 |
+
resampled = getattr(args, "dataset_resampled", False)
|
580 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
581 |
+
|
582 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
583 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
584 |
+
preprocess_text_fn = functools.partial(preprocess_encoded_text, tokenizer=tokenizer, max_length=args.max_length)
|
585 |
+
pipeline = [
|
586 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
|
587 |
+
tarfile_to_samples_nothrow,
|
588 |
+
wds.shuffle(
|
589 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
590 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
591 |
+
),
|
592 |
+
wds.to_tuple("txt", handler=log_and_continue),
|
593 |
+
wds.map_tuple(
|
594 |
+
preprocess_text_fn, handler=log_and_continue
|
595 |
+
),
|
596 |
+
]
|
597 |
+
# with_epoch(sys.maxsize) will give us an infinite sample stream
|
598 |
+
dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
|
599 |
+
delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
|
600 |
+
dataset = ConcatDataset(iter(dataset), max_length=args.max_length, delimiter_id=delimiter_id)
|
601 |
+
|
602 |
+
|
603 |
+
def text_collate_fn(items):
|
604 |
+
try:
|
605 |
+
input_ids = torch.cat([x[0].unsqueeze(0) for x in items], dim=0)
|
606 |
+
attention_mask = torch.cat([x[1].unsqueeze(0) for x in items], dim=0)
|
607 |
+
return input_ids, attention_mask
|
608 |
+
except:
|
609 |
+
return None, None
|
610 |
+
|
611 |
+
dataloader = wds.WebLoader(
|
612 |
+
dataset,
|
613 |
+
batch_size=args.batch_size_pile,
|
614 |
+
shuffle=False,
|
615 |
+
num_workers=args.workers,
|
616 |
+
persistent_workers=False,
|
617 |
+
collate_fn=text_collate_fn,
|
618 |
+
)
|
619 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
620 |
+
|
621 |
+
|
622 |
+
# FIXME:
|
623 |
+
# modify /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/webdataset/filters.py, line 433
|
624 |
+
# combine_tensors=True to combine_tensors=False
|
625 |
+
def get_ground_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
626 |
+
input_shards = args.laion_shards
|
627 |
+
assert input_shards is not None
|
628 |
+
resampled = getattr(args, "dataset_resampled", False)
|
629 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
630 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
631 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
632 |
+
generator = caption_grounder(
|
633 |
+
config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
634 |
+
checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
635 |
+
cpu_only=True,
|
636 |
+
# box_threshold=0.5, text_threshold=0.3,
|
637 |
+
)
|
638 |
+
preprocess_ground_caption_fn = functools.partial(
|
639 |
+
preprocess_ground_caption, image_processor=image_processor, tokenizer=tokenizer,
|
640 |
+
image_embedding_size=args.vis_embed_size, single=args.single, generator=generator,
|
641 |
+
prob_ground=args.prob_ground, use_format_v2=args.use_format_v2,
|
642 |
+
add_visual_token=args.add_visual_token, max_length=args.max_length,
|
643 |
+
args=args,
|
644 |
+
)
|
645 |
+
pipeline = [
|
646 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
|
647 |
+
tarfile_to_samples_nothrow,
|
648 |
+
wds.shuffle(
|
649 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
650 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
651 |
+
),
|
652 |
+
wds.select(filter_no_caption_or_no_image),
|
653 |
+
wds.decode("pilrgb", partial=True, handler=log_and_continue),
|
654 |
+
wds.to_tuple("jpg;png;jpeg", "txt", "logits.pyd", "boxes.pyd", "relation_data", handler=log_and_continue),
|
655 |
+
wds.map(
|
656 |
+
preprocess_ground_caption_fn, handler=log_and_continue
|
657 |
+
),
|
658 |
+
]
|
659 |
+
|
660 |
+
dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
|
661 |
+
# for sample in dataset:
|
662 |
+
# print(tokenizer.decode(sample[1][0]).replace("<PAD>", ""))
|
663 |
+
# DEBUG
|
664 |
+
# dataset = wds.DataPipeline(*pipeline)
|
665 |
+
# from tqdm import tqdm
|
666 |
+
# for sample in tqdm(dataset):
|
667 |
+
# nn = 0
|
668 |
+
# for x in sample[1][0]:
|
669 |
+
# if x == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]:
|
670 |
+
# nn += 1
|
671 |
+
# if x == tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]:
|
672 |
+
# nn -= 1
|
673 |
+
# if nn not in [0, 1]:
|
674 |
+
# print(tokenizer.decode(sample[1][0]).replace("<PAD>", ""))
|
675 |
+
# import pdb; pdb.set_trace()
|
676 |
+
# if nn != 0:
|
677 |
+
# print(tokenizer.decode(sample[1][0]).replace("<PAD>", ""))
|
678 |
+
# import pdb; pdb.set_trace()
|
679 |
+
# from groundingdino.demo.inference_on_laion import OBJ_LENGTHS
|
680 |
+
# # import pdb; pdb.set_trace()
|
681 |
+
# print(sum(OBJ_LENGTHS) / len(OBJ_LENGTHS))
|
682 |
+
# exit()
|
683 |
+
# DEBUG
|
684 |
+
|
685 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
686 |
+
delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
|
687 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
688 |
+
box_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
689 |
+
visual_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
690 |
+
dataset = ConcatDataset(
|
691 |
+
iter(dataset), max_length=args.max_length,
|
692 |
+
delimiter_id=delimiter_id,
|
693 |
+
pad_id=tokenizer.pad_token_id,
|
694 |
+
media_id=media_token_id,
|
695 |
+
endofmedia_id=endofmedia_token_id,
|
696 |
+
box_id=box_id,
|
697 |
+
visual_id=visual_id,
|
698 |
+
image_embedding_size=args.vis_embed_size,
|
699 |
+
single=args.single,
|
700 |
+
)
|
701 |
+
|
702 |
+
def image_collate_fn(items):
|
703 |
+
images = torch.cat([x[0] for x in items], dim=0)
|
704 |
+
image_nums = [x[1] for x in items]
|
705 |
+
image_start_index_list = [x[2] for x in items]
|
706 |
+
input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
|
707 |
+
attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
|
708 |
+
added_bbox_list = [x[5] for x in items]
|
709 |
+
expand_list = added_bbox_list[0]
|
710 |
+
for x in added_bbox_list[1:]:
|
711 |
+
expand_list.extend(x)
|
712 |
+
relations_list = [x[6] for x in items]
|
713 |
+
return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list, relations_list
|
714 |
+
|
715 |
+
dataloader = wds.WebLoader(
|
716 |
+
dataset,
|
717 |
+
batch_size=args.batch_size_laion,
|
718 |
+
shuffle=False,
|
719 |
+
num_workers=args.workers,
|
720 |
+
persistent_workers=False,
|
721 |
+
collate_fn=image_collate_fn,
|
722 |
+
)
|
723 |
+
round_fn = math.floor if floor else math.ceil
|
724 |
+
global_batch_size = args.batch_size_laion * args.world_size
|
725 |
+
num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
|
726 |
+
dataloader.num_batches = num_batches
|
727 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
728 |
+
|
729 |
+
|
730 |
+
def get_image_text_pair_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
731 |
+
input_shards = args.laion_shards
|
732 |
+
assert input_shards is not None
|
733 |
+
resampled = getattr(args, "dataset_resampled", False)
|
734 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
735 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
736 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
737 |
+
preprocess_caption_fn = functools.partial(
|
738 |
+
preprocess_caption, image_processor=image_processor, tokenizer=tokenizer,
|
739 |
+
image_embedding_size=args.vis_embed_size, single=args.single,
|
740 |
+
max_length=args.max_length,
|
741 |
+
)
|
742 |
+
pipeline = [
|
743 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
|
744 |
+
tarfile_to_samples_nothrow,
|
745 |
+
wds.shuffle(
|
746 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
747 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
748 |
+
),
|
749 |
+
wds.select(filter_no_caption_or_no_image),
|
750 |
+
wds.decode("pilrgb", handler=log_and_continue),
|
751 |
+
wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
|
752 |
+
wds.map(
|
753 |
+
preprocess_caption_fn, handler=log_and_continue
|
754 |
+
),
|
755 |
+
]
|
756 |
+
|
757 |
+
dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
|
758 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
759 |
+
delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
|
760 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
761 |
+
dataset = ConcatDataset(
|
762 |
+
iter(dataset), max_length=args.max_length,
|
763 |
+
delimiter_id=delimiter_id,
|
764 |
+
pad_id=tokenizer.pad_token_id,
|
765 |
+
media_id=media_token_id,
|
766 |
+
endofmedia_id=endofmedia_token_id,
|
767 |
+
image_embedding_size=args.vis_embed_size,
|
768 |
+
single=args.single,
|
769 |
+
)
|
770 |
+
|
771 |
+
def image_collate_fn(items):
|
772 |
+
images = torch.cat([x[0] for x in items], dim=0)
|
773 |
+
image_nums = [x[1] for x in items]
|
774 |
+
image_start_index_list = [x[2] for x in items]
|
775 |
+
input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
|
776 |
+
attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
|
777 |
+
return images, image_nums, image_start_index_list, input_ids, attention_mask
|
778 |
+
|
779 |
+
dataloader = wds.WebLoader(
|
780 |
+
dataset,
|
781 |
+
batch_size=args.batch_size_laion,
|
782 |
+
shuffle=False,
|
783 |
+
num_workers=args.workers,
|
784 |
+
persistent_workers=False,
|
785 |
+
collate_fn=image_collate_fn,
|
786 |
+
)
|
787 |
+
round_fn = math.floor if floor else math.ceil
|
788 |
+
global_batch_size = args.batch_size_laion * args.world_size
|
789 |
+
num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
|
790 |
+
dataloader.num_batches = num_batches
|
791 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
792 |
+
|
793 |
+
|
794 |
+
def get_instruct_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
|
795 |
+
input_shards = args.laion_shards
|
796 |
+
assert input_shards is not None
|
797 |
+
resampled = getattr(args, "dataset_resampled", False)
|
798 |
+
assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
|
799 |
+
# create a shared epoch store to sync epoch to dataloader worker proc
|
800 |
+
shared_epoch = SharedEpoch(epoch=epoch)
|
801 |
+
preprocess_instruct_fn = functools.partial(
|
802 |
+
preprocess_instruct, image_processor=image_processor, tokenizer=tokenizer,
|
803 |
+
image_embedding_size=args.vis_embed_size,
|
804 |
+
max_length=args.max_length,
|
805 |
+
)
|
806 |
+
pipeline = [
|
807 |
+
ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
|
808 |
+
tarfile_to_samples_nothrow,
|
809 |
+
wds.shuffle(
|
810 |
+
bufsize=_SAMPLE_SHUFFLE_SIZE,
|
811 |
+
initial=_SAMPLE_SHUFFLE_INITIAL,
|
812 |
+
),
|
813 |
+
wds.decode(partial=True),
|
814 |
+
wds.to_tuple("image_path.txt", "dataset.txt", "data.pyd", handler=log_and_continue),
|
815 |
+
wds.map(
|
816 |
+
preprocess_instruct_fn, handler=log_and_continue
|
817 |
+
),
|
818 |
+
]
|
819 |
+
dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
|
820 |
+
|
821 |
+
def image_collate_fn(items):
|
822 |
+
images = torch.cat([x[0] for x in items], dim=0)
|
823 |
+
image_nums = [x[1] for x in items]
|
824 |
+
image_start_index_list = [x[2] for x in items]
|
825 |
+
input_ids = torch.cat([x[3] for x in items], dim=0)
|
826 |
+
attention_mask = torch.cat([x[4] for x in items], dim=0)
|
827 |
+
added_bbox_list = [x[5] for x in items]
|
828 |
+
expand_list = added_bbox_list[0]
|
829 |
+
for x in added_bbox_list[1:]:
|
830 |
+
expand_list.extend(x)
|
831 |
+
return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list
|
832 |
+
|
833 |
+
dataloader = wds.WebLoader(
|
834 |
+
dataset,
|
835 |
+
batch_size=args.batch_size_laion,
|
836 |
+
shuffle=False,
|
837 |
+
num_workers=args.workers,
|
838 |
+
persistent_workers=False,
|
839 |
+
collate_fn=image_collate_fn,
|
840 |
+
)
|
841 |
+
round_fn = math.floor if floor else math.ceil
|
842 |
+
global_batch_size = args.batch_size_laion * args.world_size
|
843 |
+
num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
|
844 |
+
dataloader.num_batches = num_batches
|
845 |
+
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
846 |
+
|
847 |
+
|
848 |
+
def get_dataset_fn(dataset_type):
|
849 |
+
if dataset_type == "mmc4":
|
850 |
+
raise NotImplementedError
|
851 |
+
elif dataset_type == "pile":
|
852 |
+
return get_pile_dataset
|
853 |
+
elif dataset_type == "ground_image_text":
|
854 |
+
return get_ground_laion_dataset
|
855 |
+
elif dataset_type == "image_text":
|
856 |
+
return get_image_text_pair_dataset
|
857 |
+
elif dataset_type == "vqav2":
|
858 |
+
raise NotImplementedError
|
859 |
+
elif dataset_type == "instruct":
|
860 |
+
return get_instruct_dataset
|
861 |
+
else:
|
862 |
+
raise ValueError(f"Unsupported dataset type: {dataset_type}")
|
863 |
+
|
864 |
+
|
865 |
+
def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
|
866 |
+
return get_dataset_fn(dataset_type)(
|
867 |
+
args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
|
868 |
+
)
|
multimodal/build/lib/open_flamingo/train/distributed.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
import horovod.torch as hvd
|
7 |
+
except ImportError:
|
8 |
+
hvd = None
|
9 |
+
|
10 |
+
|
11 |
+
def is_global_master(args):
|
12 |
+
return args.rank == 0
|
13 |
+
|
14 |
+
|
15 |
+
def is_local_master(args):
|
16 |
+
return args.local_rank == 0
|
17 |
+
|
18 |
+
|
19 |
+
def is_master(args, local=False):
|
20 |
+
return is_local_master(args) if local else is_global_master(args)
|
21 |
+
|
22 |
+
|
23 |
+
def is_using_horovod():
|
24 |
+
# NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
|
25 |
+
# Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
|
26 |
+
ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
|
27 |
+
pmi_vars = ["PMI_RANK", "PMI_SIZE"]
|
28 |
+
if all([var in os.environ for var in ompi_vars]) or all(
|
29 |
+
[var in os.environ for var in pmi_vars]
|
30 |
+
):
|
31 |
+
return True
|
32 |
+
else:
|
33 |
+
return False
|
34 |
+
|
35 |
+
|
36 |
+
def is_using_distributed():
|
37 |
+
if "WORLD_SIZE" in os.environ:
|
38 |
+
return int(os.environ["WORLD_SIZE"]) > 1
|
39 |
+
if "SLURM_NTASKS" in os.environ:
|
40 |
+
return int(os.environ["SLURM_NTASKS"]) > 1
|
41 |
+
return False
|
42 |
+
|
43 |
+
|
44 |
+
def world_info_from_env():
|
45 |
+
local_rank = 0
|
46 |
+
for v in (
|
47 |
+
"LOCAL_RANK",
|
48 |
+
"MPI_LOCALRANKID",
|
49 |
+
"SLURM_LOCALID",
|
50 |
+
"OMPI_COMM_WORLD_LOCAL_RANK",
|
51 |
+
):
|
52 |
+
if v in os.environ:
|
53 |
+
local_rank = int(os.environ[v])
|
54 |
+
break
|
55 |
+
global_rank = 0
|
56 |
+
for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
|
57 |
+
if v in os.environ:
|
58 |
+
global_rank = int(os.environ[v])
|
59 |
+
break
|
60 |
+
world_size = 1
|
61 |
+
for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
|
62 |
+
if v in os.environ:
|
63 |
+
world_size = int(os.environ[v])
|
64 |
+
break
|
65 |
+
|
66 |
+
return local_rank, global_rank, world_size
|
67 |
+
|
68 |
+
|
69 |
+
def init_distributed_device(args):
|
70 |
+
# Distributed training = training on more than one GPU.
|
71 |
+
# Works in both single and multi-node scenarios.
|
72 |
+
args.distributed = False
|
73 |
+
args.world_size = 1
|
74 |
+
args.rank = 0 # global rank
|
75 |
+
args.local_rank = 0
|
76 |
+
if args.horovod:
|
77 |
+
assert hvd is not None, "Horovod is not installed"
|
78 |
+
hvd.init()
|
79 |
+
args.local_rank = int(hvd.local_rank())
|
80 |
+
args.rank = hvd.rank()
|
81 |
+
args.world_size = hvd.size()
|
82 |
+
args.distributed = True
|
83 |
+
os.environ["LOCAL_RANK"] = str(args.local_rank)
|
84 |
+
os.environ["RANK"] = str(args.rank)
|
85 |
+
os.environ["WORLD_SIZE"] = str(args.world_size)
|
86 |
+
elif is_using_distributed():
|
87 |
+
if "SLURM_PROCID" in os.environ:
|
88 |
+
# DDP via SLURM
|
89 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
90 |
+
# SLURM var -> torch.distributed vars in case needed
|
91 |
+
os.environ["LOCAL_RANK"] = str(args.local_rank)
|
92 |
+
os.environ["RANK"] = str(args.rank)
|
93 |
+
os.environ["WORLD_SIZE"] = str(args.world_size)
|
94 |
+
torch.distributed.init_process_group(
|
95 |
+
backend=args.dist_backend,
|
96 |
+
init_method=args.dist_url,
|
97 |
+
world_size=args.world_size,
|
98 |
+
rank=args.rank,
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
# DDP via torchrun, torch.distributed.launch
|
102 |
+
args.local_rank, _, _ = world_info_from_env()
|
103 |
+
torch.distributed.init_process_group(
|
104 |
+
backend=args.dist_backend, init_method=args.dist_url
|
105 |
+
)
|
106 |
+
args.world_size = torch.distributed.get_world_size()
|
107 |
+
args.rank = torch.distributed.get_rank()
|
108 |
+
args.distributed = True
|
109 |
+
else:
|
110 |
+
# needed to run on single gpu
|
111 |
+
torch.distributed.init_process_group(
|
112 |
+
backend=args.dist_backend,
|
113 |
+
init_method=args.dist_url,
|
114 |
+
world_size=1,
|
115 |
+
rank=0,
|
116 |
+
)
|
117 |
+
|
118 |
+
if torch.cuda.is_available():
|
119 |
+
if args.distributed and not args.no_set_device_rank:
|
120 |
+
device = "cuda:%d" % args.local_rank
|
121 |
+
else:
|
122 |
+
device = "cuda:0"
|
123 |
+
torch.cuda.set_device(device)
|
124 |
+
else:
|
125 |
+
device = "cpu"
|
126 |
+
args.device = device
|
127 |
+
device = torch.device(device)
|
128 |
+
return device
|
multimodal/build/lib/open_flamingo/train/instruction_template.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
VG_RELATION_TEMPLATES = [
|
2 |
+
"Question: What is the relationship between<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> and<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer: {relation}.",
|
3 |
+
"Question: What is the relationship between<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> and<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
|
4 |
+
"Question: What {is_or_does}<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {relation_do}? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {use_is} {relation}<|#object#|>{nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
|
5 |
+
"Question: What {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
|
6 |
+
"Question: What {is_or_does}<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {relation_do}? Answer:<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
|
7 |
+
"Question: What {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
|
8 |
+
]
|
9 |
+
|
10 |
+
PISC_TEMPLATES = [
|
11 |
+
"Question: What is the social relationship between this<|#object#|> person<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> and that<|#object#|> person<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer: {relation}.",
|
12 |
+
"Question: What is the social relationship between these<|#object#|> people<|#endofobject#|><|#visual#|><|#box#|><|#box#|><|#endofobject#|>? Answer: {relation}.",
|
13 |
+
]
|
multimodal/build/lib/open_flamingo/train/train.py
ADDED
@@ -0,0 +1,709 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Main training script """
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import copy
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import functools
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
# torch.multiprocessing.set_sharing_strategy('file_system')
|
13 |
+
import wandb
|
14 |
+
from data2 import get_data
|
15 |
+
from distributed import init_distributed_device, world_info_from_env
|
16 |
+
from torch.distributed.fsdp import (
|
17 |
+
FullyShardedDataParallel as FSDP,
|
18 |
+
MixedPrecision,
|
19 |
+
BackwardPrefetch,
|
20 |
+
ShardingStrategy,
|
21 |
+
FullStateDictConfig,
|
22 |
+
CPUOffload,
|
23 |
+
StateDictType,
|
24 |
+
)
|
25 |
+
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
26 |
+
from torch.distributed.fsdp.wrap import (
|
27 |
+
transformer_auto_wrap_policy,
|
28 |
+
enable_wrap,
|
29 |
+
wrap,
|
30 |
+
)
|
31 |
+
|
32 |
+
from train_utils import train_one_epoch
|
33 |
+
from transformers import (
|
34 |
+
get_constant_schedule_with_warmup,
|
35 |
+
get_cosine_schedule_with_warmup,
|
36 |
+
get_linear_schedule_with_warmup,
|
37 |
+
)
|
38 |
+
|
39 |
+
from open_flamingo import create_model_and_transforms
|
40 |
+
from torch.utils.tensorboard import SummaryWriter
|
41 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
42 |
+
from torch.cuda.amp import GradScaler
|
43 |
+
from torch.distributed.optim import ZeroRedundancyOptimizer
|
44 |
+
import warnings
|
45 |
+
warnings.filterwarnings("ignore")
|
46 |
+
import logging
|
47 |
+
logging.basicConfig(
|
48 |
+
level=logging.INFO,
|
49 |
+
format='%(asctime)s %(message)s',
|
50 |
+
datefmt='%m/%d %I:%M:%S',
|
51 |
+
)
|
52 |
+
|
53 |
+
class FakeDataloader:
|
54 |
+
def __iter__(self):
|
55 |
+
return self
|
56 |
+
|
57 |
+
def __next__(self):
|
58 |
+
return None
|
59 |
+
|
60 |
+
def random_seed(seed=42, rank=0):
|
61 |
+
torch.manual_seed(seed + rank)
|
62 |
+
np.random.seed(seed + rank)
|
63 |
+
random.seed(seed + rank)
|
64 |
+
|
65 |
+
|
66 |
+
def get_grouped_params(model, args):
|
67 |
+
params_with_wd, params_without_wd = [], []
|
68 |
+
|
69 |
+
def apply_decay(x):
|
70 |
+
x = x.lower()
|
71 |
+
return "norm" not in x and "bn" not in x and "bias" not in x and "embed" not in x and "wte" not in x and "flat_param" not in x
|
72 |
+
|
73 |
+
for n, p in model.named_parameters():
|
74 |
+
# if p.requires_grad:
|
75 |
+
if apply_decay(n):
|
76 |
+
if torch.distributed.get_rank() == 0:
|
77 |
+
logging.info(f"with wd: {n}")
|
78 |
+
params_with_wd.append(p)
|
79 |
+
else:
|
80 |
+
if torch.distributed.get_rank() == 0:
|
81 |
+
logging.info(f"without wd: {n}")
|
82 |
+
params_without_wd.append(p)
|
83 |
+
return [
|
84 |
+
{"params": params_with_wd, "weight_decay": args.weight_decay},
|
85 |
+
{"params": params_without_wd, "weight_decay": 0.0},
|
86 |
+
]
|
87 |
+
|
88 |
+
|
89 |
+
def lambda_policy_fn(module):
|
90 |
+
if (
|
91 |
+
len(list(module.named_children())) == 0
|
92 |
+
and getattr(module, "weight", None) is not None
|
93 |
+
and module.weight.requires_grad
|
94 |
+
):
|
95 |
+
return True
|
96 |
+
return False
|
97 |
+
|
98 |
+
|
99 |
+
def lambda_auto_wrap_policy(
|
100 |
+
module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn,
|
101 |
+
) -> bool:
|
102 |
+
"""
|
103 |
+
A convenient auto wrap policy to wrap submodules based on an arbitrary user
|
104 |
+
function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as
|
105 |
+
a `wrapper_cls` unit.
|
106 |
+
|
107 |
+
Return if a module should be wrapped during auto wrapping.
|
108 |
+
|
109 |
+
The first three parameters are required by :func:`_recursive_wrap`.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
module (nn.Module): Current module being considered.
|
113 |
+
recurse (bool): If ``False``, then this function must decide whether
|
114 |
+
``module`` should be wrapped as an FSDP instance or not. If
|
115 |
+
``True``, then the function is still recursing down the module
|
116 |
+
tree as a part of the DFS.
|
117 |
+
nonwrapped_numel (int): Parameter numel not yet wrapped.
|
118 |
+
|
119 |
+
lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
|
120 |
+
this module will be wrapped.
|
121 |
+
"""
|
122 |
+
if recurse:
|
123 |
+
return True # always recurse
|
124 |
+
return lambda_fn(module)
|
125 |
+
|
126 |
+
|
127 |
+
def main():
|
128 |
+
parser = argparse.ArgumentParser()
|
129 |
+
parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str)
|
130 |
+
parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str)
|
131 |
+
parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
|
132 |
+
parser.add_argument(
|
133 |
+
"--tokenizer_path",
|
134 |
+
default="facebook/opt-1.3b",
|
135 |
+
type=str,
|
136 |
+
help="path to tokenizer",
|
137 |
+
)
|
138 |
+
parser.add_argument(
|
139 |
+
"--run_name",
|
140 |
+
type=str,
|
141 |
+
default="openflamingo3B",
|
142 |
+
help="used to name saving directory and wandb run",
|
143 |
+
)
|
144 |
+
parser.add_argument("--use_media_placement_augmentation", action="store_true")
|
145 |
+
parser.add_argument("--offline", action="store_true")
|
146 |
+
parser.add_argument("--num_steps", type=int, default=300000)
|
147 |
+
parser.add_argument(
|
148 |
+
"--logging_steps", type=int, default=10, help="log loss every n steps"
|
149 |
+
)
|
150 |
+
# Sum of gradient optimization batch size
|
151 |
+
parser.add_argument("--batch_size_mmc4", type=int, default=128)
|
152 |
+
parser.add_argument("--batch_size_laion", type=int, default=128)
|
153 |
+
parser.add_argument("--batch_size_pile", type=int, default=128)
|
154 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
155 |
+
parser.add_argument(
|
156 |
+
"--resume_from_checkpoint",
|
157 |
+
type=str,
|
158 |
+
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
|
159 |
+
default=None,
|
160 |
+
)
|
161 |
+
parser.add_argument(
|
162 |
+
"--delete_previous_checkpoint",
|
163 |
+
action="store_true",
|
164 |
+
help="delete previous checkpoint when saving new checkpoint",
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--laion_shards",
|
168 |
+
type=str,
|
169 |
+
help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
"--mmc4_shards",
|
173 |
+
type=str,
|
174 |
+
help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
|
175 |
+
)
|
176 |
+
parser.add_argument(
|
177 |
+
"--pile_shards",
|
178 |
+
type=str,
|
179 |
+
default=None,
|
180 |
+
help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
|
181 |
+
)
|
182 |
+
parser.add_argument("--seed", type=int, default=42)
|
183 |
+
parser.add_argument("--learning_rate", default=1e-4, type=float)
|
184 |
+
parser.add_argument(
|
185 |
+
"--lr_scheduler",
|
186 |
+
default="constant",
|
187 |
+
type=str,
|
188 |
+
help="constant, linear, or cosine",
|
189 |
+
)
|
190 |
+
parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
|
191 |
+
parser.add_argument("--loss_multiplier_laion", type=float, default=1.0)
|
192 |
+
parser.add_argument("--loss_multiplier_pile", type=float, default=1.0)
|
193 |
+
parser.add_argument("--loss_multiplier_det", type=float, default=1.0)
|
194 |
+
parser.add_argument("--loss_multiplier_rel", type=float, default=1.0)
|
195 |
+
parser.add_argument("--loss_multiplier_attn", type=float, default=1.0)
|
196 |
+
parser.add_argument("--warmup_steps", default=5000, type=int)
|
197 |
+
# weight decay is only apply to YOLOX head if using FSDP
|
198 |
+
# https://medium.com/@huanghaian123/optimize-and-accelerate-yolox-with-rtmdet-hyps-in-mmyolo-80fc06d61159
|
199 |
+
parser.add_argument("--weight_decay", default=0.05, type=float)
|
200 |
+
parser.add_argument(
|
201 |
+
"--precision",
|
202 |
+
choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
|
203 |
+
default="fp32",
|
204 |
+
help="Floating point precision.",
|
205 |
+
)
|
206 |
+
# data args
|
207 |
+
parser.add_argument("--workers", type=int, default=1)
|
208 |
+
parser.add_argument("--dataset_resampled", action="store_true")
|
209 |
+
# distributed training args
|
210 |
+
parser.add_argument(
|
211 |
+
"--dist-url",
|
212 |
+
default="env://",
|
213 |
+
type=str,
|
214 |
+
help="url used to set up distributed training",
|
215 |
+
)
|
216 |
+
parser.add_argument(
|
217 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
218 |
+
)
|
219 |
+
parser.add_argument(
|
220 |
+
"--horovod",
|
221 |
+
default=False,
|
222 |
+
action="store_true",
|
223 |
+
help="Use horovod for distributed training.",
|
224 |
+
)
|
225 |
+
parser.add_argument(
|
226 |
+
"--no-set-device-rank",
|
227 |
+
default=False,
|
228 |
+
action="store_true",
|
229 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
230 |
+
)
|
231 |
+
# wandb args
|
232 |
+
parser.add_argument("--report_to_wandb", default=False, action="store_true")
|
233 |
+
parser.add_argument(
|
234 |
+
"--wandb_project",
|
235 |
+
type=str,
|
236 |
+
)
|
237 |
+
parser.add_argument(
|
238 |
+
"--wandb_entity",
|
239 |
+
type=str,
|
240 |
+
)
|
241 |
+
parser.add_argument(
|
242 |
+
"--save_checkpoints_to_wandb",
|
243 |
+
default=False,
|
244 |
+
action="store_true",
|
245 |
+
help="save checkpoints to wandb",
|
246 |
+
)
|
247 |
+
parser.add_argument(
|
248 |
+
"--checkpoint_activations",
|
249 |
+
default=False,
|
250 |
+
action="store_true",
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--freeze_vision_encoder",
|
254 |
+
default=False,
|
255 |
+
action="store_true",
|
256 |
+
)
|
257 |
+
parser.add_argument(
|
258 |
+
"--mmc4_textsim_threshold",
|
259 |
+
default=30,
|
260 |
+
type=float,
|
261 |
+
help="threshold for filtering images in mmc4 based on image-text similarity",
|
262 |
+
)
|
263 |
+
parser.add_argument(
|
264 |
+
"--location_token_num",
|
265 |
+
default=1000,
|
266 |
+
type=int,
|
267 |
+
)
|
268 |
+
parser.add_argument(
|
269 |
+
"--vis_embed_size",
|
270 |
+
type=int,
|
271 |
+
required=False,
|
272 |
+
)
|
273 |
+
parser.add_argument(
|
274 |
+
"--save_interval",
|
275 |
+
default=1000,
|
276 |
+
type=int,
|
277 |
+
required=False,
|
278 |
+
)
|
279 |
+
parser.add_argument(
|
280 |
+
"--skip_delete_pattern",
|
281 |
+
default=1500,
|
282 |
+
type=int,
|
283 |
+
required=False,
|
284 |
+
)
|
285 |
+
parser.add_argument(
|
286 |
+
"--ddp",
|
287 |
+
default=False,
|
288 |
+
action="store_true",
|
289 |
+
)
|
290 |
+
parser.add_argument(
|
291 |
+
"--pile_freq",
|
292 |
+
default=1,
|
293 |
+
type=int,
|
294 |
+
required=False,
|
295 |
+
)
|
296 |
+
parser.add_argument(
|
297 |
+
"--restart",
|
298 |
+
default=False,
|
299 |
+
action="store_true",
|
300 |
+
)
|
301 |
+
parser.add_argument(
|
302 |
+
"--lora",
|
303 |
+
default=False,
|
304 |
+
action="store_true",
|
305 |
+
)
|
306 |
+
parser.add_argument(
|
307 |
+
"--lora_r",
|
308 |
+
default=16,
|
309 |
+
type=int,
|
310 |
+
required=False,
|
311 |
+
)
|
312 |
+
parser.add_argument(
|
313 |
+
"--single",
|
314 |
+
default=False,
|
315 |
+
action="store_true",
|
316 |
+
)
|
317 |
+
|
318 |
+
# Finetune
|
319 |
+
parser.add_argument(
|
320 |
+
"--instruct",
|
321 |
+
default=False,
|
322 |
+
action="store_true",
|
323 |
+
)
|
324 |
+
parser.add_argument(
|
325 |
+
"--fix-ffn",
|
326 |
+
default=False,
|
327 |
+
action="store_true",
|
328 |
+
)
|
329 |
+
parser.add_argument(
|
330 |
+
"--prob_ground",
|
331 |
+
default=1.0,
|
332 |
+
type=float,
|
333 |
+
required=False,
|
334 |
+
)
|
335 |
+
parser.add_argument(
|
336 |
+
"--optimizer",
|
337 |
+
default="adamw",
|
338 |
+
type=str,
|
339 |
+
required=False,
|
340 |
+
)
|
341 |
+
parser.add_argument(
|
342 |
+
"--add_visual_token",
|
343 |
+
default=False,
|
344 |
+
action="store_true",
|
345 |
+
)
|
346 |
+
parser.add_argument(
|
347 |
+
"--use_format_v2",
|
348 |
+
default=False,
|
349 |
+
action="store_true",
|
350 |
+
)
|
351 |
+
parser.add_argument(
|
352 |
+
"--use_sam",
|
353 |
+
default=None,
|
354 |
+
type=str,
|
355 |
+
required=False,
|
356 |
+
)
|
357 |
+
parser.add_argument(
|
358 |
+
"--max-length",
|
359 |
+
default=608,
|
360 |
+
type=int,
|
361 |
+
required=False,
|
362 |
+
)
|
363 |
+
parser.add_argument(
|
364 |
+
"--image-size",
|
365 |
+
default=256,
|
366 |
+
type=int,
|
367 |
+
required=False,
|
368 |
+
)
|
369 |
+
parser.add_argument(
|
370 |
+
"--reset_llm",
|
371 |
+
default=False,
|
372 |
+
action="store_true",
|
373 |
+
)
|
374 |
+
parser.add_argument(
|
375 |
+
"--add_box",
|
376 |
+
default=False,
|
377 |
+
action="store_true",
|
378 |
+
)
|
379 |
+
parser.add_argument(
|
380 |
+
"--add_pe",
|
381 |
+
default=False,
|
382 |
+
action="store_true",
|
383 |
+
)
|
384 |
+
parser.add_argument(
|
385 |
+
"--only_grounded_sample",
|
386 |
+
default=False,
|
387 |
+
action="store_true",
|
388 |
+
)
|
389 |
+
parser.add_argument(
|
390 |
+
"--expand",
|
391 |
+
default=False,
|
392 |
+
action="store_true",
|
393 |
+
)
|
394 |
+
parser.add_argument(
|
395 |
+
"--delete_contained",
|
396 |
+
default=False,
|
397 |
+
action="store_true",
|
398 |
+
)
|
399 |
+
|
400 |
+
parser.add_argument(
|
401 |
+
"--relation",
|
402 |
+
default=False,
|
403 |
+
action="store_true",
|
404 |
+
)
|
405 |
+
parser.add_argument(
|
406 |
+
"--attn_reg",
|
407 |
+
default="l1",
|
408 |
+
type=str,
|
409 |
+
required=False,
|
410 |
+
)
|
411 |
+
parser.add_argument(
|
412 |
+
"--enhance_data",
|
413 |
+
default=False,
|
414 |
+
action="store_true",
|
415 |
+
)
|
416 |
+
parser.add_argument(
|
417 |
+
"--no_visual",
|
418 |
+
default=False,
|
419 |
+
action="store_true",
|
420 |
+
)
|
421 |
+
parser.add_argument(
|
422 |
+
"--no_previsual",
|
423 |
+
default=False,
|
424 |
+
action="store_true",
|
425 |
+
)
|
426 |
+
parser.add_argument(
|
427 |
+
"--roi_align",
|
428 |
+
default=False,
|
429 |
+
action="store_true",
|
430 |
+
)
|
431 |
+
parser.add_argument(
|
432 |
+
"--roi_output_size",
|
433 |
+
default=4,
|
434 |
+
type=int,
|
435 |
+
required=False,
|
436 |
+
)
|
437 |
+
parser.add_argument(
|
438 |
+
"--apply_mask",
|
439 |
+
default=False,
|
440 |
+
action="store_true",
|
441 |
+
)
|
442 |
+
parser.add_argument(
|
443 |
+
"--longer_previsual",
|
444 |
+
default=False,
|
445 |
+
action="store_true",
|
446 |
+
)
|
447 |
+
|
448 |
+
args = parser.parse_args()
|
449 |
+
assert not args.use_media_placement_augmentation, "Do not enable use_media_placement_augmentation"
|
450 |
+
if args.no_previsual:
|
451 |
+
assert args.no_visual, "no_previsual MUST come with no_visual"
|
452 |
+
assert not args.enhance_data, "dont enable enhance_data"
|
453 |
+
|
454 |
+
if args.offline:
|
455 |
+
os.environ["WANDB_MODE"] = "offline"
|
456 |
+
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
457 |
+
|
458 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
459 |
+
print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
|
460 |
+
device_id = init_distributed_device(args)
|
461 |
+
|
462 |
+
random_seed(args.seed)
|
463 |
+
model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms(
|
464 |
+
args.vision_encoder_path,
|
465 |
+
args.vision_encoder_pretrained,
|
466 |
+
args.lm_path,
|
467 |
+
args.tokenizer_path if args.tokenizer_path else args.lm_path,
|
468 |
+
use_local_files=args.offline,
|
469 |
+
use_media_placement_augmentation=args.use_media_placement_augmentation,
|
470 |
+
checkpoint_activations=args.checkpoint_activations,
|
471 |
+
freeze_vision_encoder=args.freeze_vision_encoder,
|
472 |
+
location_token_num=args.location_token_num,
|
473 |
+
lora=args.lora,
|
474 |
+
lora_r=args.lora_r,
|
475 |
+
fix_ffn=args.fix_ffn,
|
476 |
+
add_visual_token=args.add_visual_token,
|
477 |
+
add_box=args.add_box,
|
478 |
+
add_pe=args.add_pe,
|
479 |
+
add_relation=args.relation,
|
480 |
+
use_format_v2=args.use_format_v2,
|
481 |
+
use_sam=args.use_sam,
|
482 |
+
enhance_data=args.enhance_data,
|
483 |
+
roi_align=args.roi_align,
|
484 |
+
roi_output_size=args.roi_output_size,
|
485 |
+
apply_mask=args.apply_mask,
|
486 |
+
)
|
487 |
+
if args.reset_llm:
|
488 |
+
llm_state_dict = model.lang_encoder.state_dict()
|
489 |
+
if args.rank == 0:
|
490 |
+
print(args)
|
491 |
+
print(image_processor)
|
492 |
+
|
493 |
+
random_seed(args.seed, args.rank)
|
494 |
+
|
495 |
+
if args.rank == 0 and args.report_to_wandb:
|
496 |
+
wandb.init(
|
497 |
+
project=args.wandb_project,
|
498 |
+
entity=args.wandb_entity,
|
499 |
+
name=args.run_name,
|
500 |
+
config=vars(args),
|
501 |
+
)
|
502 |
+
|
503 |
+
device_id = args.rank % torch.cuda.device_count()
|
504 |
+
if args.ddp:
|
505 |
+
print("use ddp mode")
|
506 |
+
model = model.to(device_id)
|
507 |
+
model = DDP(model)
|
508 |
+
else:
|
509 |
+
fpSixteen = MixedPrecision(
|
510 |
+
param_dtype=torch.float16,
|
511 |
+
# Gradient communication precision.
|
512 |
+
reduce_dtype=torch.float16,
|
513 |
+
# Buffer precision.
|
514 |
+
# buffer_dtype=torch.float16,
|
515 |
+
)
|
516 |
+
# from transformers.models.opt.modeling_opt import OPTDecoderLayer
|
517 |
+
from open_clip.transformer import ResidualAttentionBlock
|
518 |
+
from open_flamingo.src.flamingo_lm import FlamingoLayer
|
519 |
+
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention
|
520 |
+
from segment_anything.modeling.image_encoder import Block
|
521 |
+
transformer_layer_cls=[
|
522 |
+
FlamingoLayer,
|
523 |
+
ResidualAttentionBlock,
|
524 |
+
Block,
|
525 |
+
]
|
526 |
+
if args.fix_ffn:
|
527 |
+
transformer_layer_cls.append(OPTAttention)
|
528 |
+
auto_wrap_policy = functools.partial(
|
529 |
+
transformer_auto_wrap_policy,
|
530 |
+
transformer_layer_cls=transformer_layer_cls,
|
531 |
+
)
|
532 |
+
if args.lora:
|
533 |
+
from torch.distributed.fsdp.wrap import _or_policy
|
534 |
+
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
|
535 |
+
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
|
536 |
+
ignored_modules = [model.vision_encoder]
|
537 |
+
# ignored_modules = None
|
538 |
+
else:
|
539 |
+
ignored_modules = [model.detection_head]
|
540 |
+
# ignored_modules = None
|
541 |
+
if args.add_pe:
|
542 |
+
ignored_modules += [model.pos_enc]
|
543 |
+
# if args.use_format_v2:
|
544 |
+
# ignored_modules += [model.lang_encoder.visual_guided_lm_head]
|
545 |
+
model = FSDP(
|
546 |
+
model,
|
547 |
+
auto_wrap_policy=auto_wrap_policy,
|
548 |
+
mixed_precision=fpSixteen,
|
549 |
+
device_id=torch.cuda.current_device(),
|
550 |
+
ignored_modules=ignored_modules,
|
551 |
+
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
|
552 |
+
)
|
553 |
+
model = model.to(device_id)
|
554 |
+
|
555 |
+
|
556 |
+
pile_dataset = None
|
557 |
+
if args.instruct:
|
558 |
+
laion_dataset = get_data(args, image_processor, tokenizer, "instruct")
|
559 |
+
else:
|
560 |
+
laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
|
561 |
+
if args.pile_shards is not None:
|
562 |
+
pile_dataset = get_data(args, image_processor, tokenizer, "pile")
|
563 |
+
|
564 |
+
|
565 |
+
optim_groups = get_grouped_params(model, args)
|
566 |
+
# optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
|
567 |
+
if args.ddp:
|
568 |
+
optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
|
569 |
+
# optimizer = ZeroRedundancyOptimizer(
|
570 |
+
# optim_groups,
|
571 |
+
# optimizer_class=torch.optim.AdamW,
|
572 |
+
# lr=args.learning_rate,
|
573 |
+
# parameters_as_bucket_view=True,
|
574 |
+
# )
|
575 |
+
else:
|
576 |
+
if args.optimizer == "adamw":
|
577 |
+
print("use adamw")
|
578 |
+
optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
|
579 |
+
elif args.optimizer == "sgd":
|
580 |
+
print("use sgd...")
|
581 |
+
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
|
582 |
+
else:
|
583 |
+
raise NotImplementedError
|
584 |
+
|
585 |
+
total_training_steps = args.num_steps
|
586 |
+
|
587 |
+
if args.rank == 0:
|
588 |
+
logging.info(f"Total training steps: {total_training_steps}")
|
589 |
+
|
590 |
+
if args.lr_scheduler == "linear":
|
591 |
+
lr_scheduler = get_linear_schedule_with_warmup(
|
592 |
+
optimizer,
|
593 |
+
num_warmup_steps=args.warmup_steps,
|
594 |
+
num_training_steps=total_training_steps,
|
595 |
+
)
|
596 |
+
elif args.lr_scheduler == "cosine":
|
597 |
+
lr_scheduler = get_cosine_schedule_with_warmup(
|
598 |
+
optimizer,
|
599 |
+
num_warmup_steps=args.warmup_steps,
|
600 |
+
num_training_steps=total_training_steps,
|
601 |
+
)
|
602 |
+
else:
|
603 |
+
lr_scheduler = get_constant_schedule_with_warmup(
|
604 |
+
optimizer, num_warmup_steps=args.warmup_steps
|
605 |
+
)
|
606 |
+
if args.ddp:
|
607 |
+
scaler = GradScaler()
|
608 |
+
else:
|
609 |
+
scaler = ShardedGradScaler()
|
610 |
+
total_laion_token = 0
|
611 |
+
total_pile_token = 0
|
612 |
+
total_laion_sample = 0
|
613 |
+
total_step = 0
|
614 |
+
|
615 |
+
# check if a checkpoint exists for this run
|
616 |
+
if os.path.exists(f"{args.run_name}"):
|
617 |
+
checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
|
618 |
+
if len(checkpoint_list) == 0:
|
619 |
+
if args.rank == 0:
|
620 |
+
logging.info(f"Found no checkpoints for run {args.run_name}.")
|
621 |
+
else:
|
622 |
+
args.resume_from_checkpoint = sorted(
|
623 |
+
checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
|
624 |
+
)[-1]
|
625 |
+
if args.rank == 0:
|
626 |
+
logging.info(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
|
627 |
+
args.restart = False
|
628 |
+
if args.rank == 0:
|
629 |
+
logging.info("do not restart because an existed checkpoint is found")
|
630 |
+
if args.resume_from_checkpoint is not None:
|
631 |
+
if args.rank == 0:
|
632 |
+
logging.info(f"Loading checkpoint from {args.resume_from_checkpoint}")
|
633 |
+
checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
|
634 |
+
torch.distributed.barrier()
|
635 |
+
if args.ddp:
|
636 |
+
model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
637 |
+
# sharded_osd = checkpoint['optimizer_state_dict']
|
638 |
+
else:
|
639 |
+
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
|
640 |
+
if args.reset_llm:
|
641 |
+
for key in checkpoint["model_state_dict"]:
|
642 |
+
if key.startswith("lang_encoder"):
|
643 |
+
if args.rank == 0:
|
644 |
+
logging.info(f"reset {key}")
|
645 |
+
llm_key = key.replace("lang_encoder.", "")
|
646 |
+
checkpoint["model_state_dict"][key] = llm_state_dict[llm_key]
|
647 |
+
model_state_dict = model.state_dict()
|
648 |
+
for key in checkpoint["model_state_dict"].keys():
|
649 |
+
if model_state_dict[key].shape != checkpoint["model_state_dict"][key].shape:
|
650 |
+
if args.rank == 0:
|
651 |
+
logging.info(f'{key}: shape mismatched! {model_state_dict[key].shape} vs {checkpoint["model_state_dict"][key].shape}')
|
652 |
+
checkpoint["model_state_dict"][key] = model_state_dict[key].clone()
|
653 |
+
del model_state_dict
|
654 |
+
model.load_state_dict(checkpoint["model_state_dict"], False)
|
655 |
+
# sharded_osd = FSDP.shard_full_optim_state_dict(checkpoint['optimizer_state_dict'], model, optim_input=optim_groups)
|
656 |
+
if not args.restart:
|
657 |
+
# optimizer.load_state_dict(sharded_osd)
|
658 |
+
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
|
659 |
+
# scaler.load_state_dict(checkpoint["scaler_state_dict"])
|
660 |
+
total_laion_token = checkpoint.get("total_laion_token", 0)
|
661 |
+
total_pile_token = checkpoint.get("total_pile_token", 0)
|
662 |
+
total_laion_sample = checkpoint.get("total_laion_sample", 0)
|
663 |
+
total_step = checkpoint.get("total_step", 0)
|
664 |
+
if args.rank == 0:
|
665 |
+
logging.info("load training statistics...")
|
666 |
+
else:
|
667 |
+
if args.rank == 0:
|
668 |
+
logging.info("restart training / finetuning. only load model weight...")
|
669 |
+
del checkpoint
|
670 |
+
if args.reset_llm:
|
671 |
+
del llm_state_dict
|
672 |
+
torch.cuda.empty_cache()
|
673 |
+
torch.distributed.barrier()
|
674 |
+
|
675 |
+
model.train()
|
676 |
+
if args.rank == 0:
|
677 |
+
if not os.path.exists(args.run_name):
|
678 |
+
os.makedirs(args.run_name)
|
679 |
+
writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog"))
|
680 |
+
else:
|
681 |
+
writer = None
|
682 |
+
|
683 |
+
laion_dataset.set_epoch(total_step)
|
684 |
+
laion_loader = laion_dataset.dataloader
|
685 |
+
if pile_dataset is not None:
|
686 |
+
pile_dataset.set_epoch(total_step)
|
687 |
+
pile_loader = pile_dataset.dataloader
|
688 |
+
else:
|
689 |
+
pile_loader = FakeDataloader()
|
690 |
+
train_one_epoch(
|
691 |
+
args=args,
|
692 |
+
model=model,
|
693 |
+
tokenizer=tokenizer,
|
694 |
+
optimizer=optimizer,
|
695 |
+
lr_scheduler=lr_scheduler,
|
696 |
+
laion_loader=laion_loader,
|
697 |
+
pile_loader=pile_loader,
|
698 |
+
device_id=device_id,
|
699 |
+
writer=writer,
|
700 |
+
scaler=scaler,
|
701 |
+
optim_groups=optim_groups,
|
702 |
+
total_laion_token=total_laion_token,
|
703 |
+
total_pile_token=total_pile_token,
|
704 |
+
total_laion_sample=total_laion_sample,
|
705 |
+
total_step=total_step,
|
706 |
+
)
|
707 |
+
|
708 |
+
if __name__ == "__main__":
|
709 |
+
main()
|
multimodal/build/lib/open_flamingo/train/train_utils.py
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from contextlib import suppress
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from tqdm import tqdm
|
7 |
+
import datetime
|
8 |
+
import os
|
9 |
+
import gc
|
10 |
+
from torch.distributed.fsdp import (
|
11 |
+
FullyShardedDataParallel as FSDP,
|
12 |
+
MixedPrecision,
|
13 |
+
BackwardPrefetch,
|
14 |
+
ShardingStrategy,
|
15 |
+
FullStateDictConfig,
|
16 |
+
StateDictType,
|
17 |
+
)
|
18 |
+
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
19 |
+
from torch.distributed.fsdp.wrap import (
|
20 |
+
transformer_auto_wrap_policy,
|
21 |
+
enable_wrap,
|
22 |
+
wrap,
|
23 |
+
)
|
24 |
+
|
25 |
+
from torch.utils.tensorboard import SummaryWriter
|
26 |
+
import logging
|
27 |
+
logging.basicConfig(
|
28 |
+
level=logging.INFO,
|
29 |
+
format='%(asctime)s %(message)s',
|
30 |
+
datefmt='%m/%d %I:%M:%S',
|
31 |
+
)
|
32 |
+
|
33 |
+
def get_cast_dtype(precision: str):
|
34 |
+
cast_dtype = None
|
35 |
+
if precision == "bf16":
|
36 |
+
cast_dtype = torch.bfloat16
|
37 |
+
elif precision == "fp16":
|
38 |
+
cast_dtype = torch.float16
|
39 |
+
return cast_dtype
|
40 |
+
|
41 |
+
|
42 |
+
def get_autocast(precision):
|
43 |
+
if precision == "amp_fp16":
|
44 |
+
return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
|
45 |
+
elif precision == "amp_bfloat16" or precision == "amp_bf16":
|
46 |
+
# amp_bfloat16 is more stable than amp float16 for clip training
|
47 |
+
return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
|
48 |
+
else:
|
49 |
+
return suppress
|
50 |
+
|
51 |
+
|
52 |
+
def get_sync(model, flag):
|
53 |
+
if flag:
|
54 |
+
return suppress
|
55 |
+
else:
|
56 |
+
return lambda: model.no_sync()
|
57 |
+
|
58 |
+
|
59 |
+
def train_one_epoch(
|
60 |
+
args,
|
61 |
+
model,
|
62 |
+
laion_loader,
|
63 |
+
pile_loader,
|
64 |
+
tokenizer,
|
65 |
+
optimizer,
|
66 |
+
lr_scheduler,
|
67 |
+
device_id,
|
68 |
+
writer: SummaryWriter,
|
69 |
+
optim_groups,
|
70 |
+
scaler,
|
71 |
+
total_laion_token: int,
|
72 |
+
total_pile_token: int,
|
73 |
+
total_laion_sample: int,
|
74 |
+
total_step: int,
|
75 |
+
):
|
76 |
+
world_size = torch.distributed.get_world_size()
|
77 |
+
autocast = get_autocast(args.precision)
|
78 |
+
cast_dtype = get_cast_dtype(args.precision)
|
79 |
+
|
80 |
+
media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
|
81 |
+
endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
|
82 |
+
visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
|
83 |
+
if args.add_box:
|
84 |
+
box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
|
85 |
+
endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
|
86 |
+
endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
|
87 |
+
if args.use_format_v2:
|
88 |
+
prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
|
89 |
+
previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
|
90 |
+
if args.rank == 0:
|
91 |
+
logging.info(f"train from: {total_step} step")
|
92 |
+
model.train()
|
93 |
+
# loop through dataloader
|
94 |
+
last_logging_step = total_step
|
95 |
+
last_save_step = total_step
|
96 |
+
for num_steps, (batch_laion, batch_pile) in tqdm(
|
97 |
+
enumerate(zip(laion_loader, pile_loader)),
|
98 |
+
disable=args.rank != 0 or "SLURM_PROCID" in os.environ,
|
99 |
+
total=args.num_steps * args.gradient_accumulation_steps,
|
100 |
+
initial=total_step * args.gradient_accumulation_steps,
|
101 |
+
):
|
102 |
+
#### LAION FORWARD PASS ####
|
103 |
+
images = (
|
104 |
+
batch_laion[0]
|
105 |
+
.to(device_id, dtype=cast_dtype, non_blocking=True)
|
106 |
+
.unsqueeze(1)
|
107 |
+
.unsqueeze(1)
|
108 |
+
)
|
109 |
+
image_nums = batch_laion[1]
|
110 |
+
image_start_index_list = batch_laion[2]
|
111 |
+
|
112 |
+
# TODO: OPT model: input_ids is not started with </s> while input_ids2 is?
|
113 |
+
input_ids = batch_laion[3].to(device_id, non_blocking=True).long()
|
114 |
+
attention_mask = batch_laion[4].to(device_id, dtype=cast_dtype, non_blocking=True)
|
115 |
+
added_bbox_list = [x.to(device_id) for x in batch_laion[5]] # list object
|
116 |
+
total_laion_token += int(attention_mask.sum().long()) * world_size
|
117 |
+
total_laion_sample += sum(image_nums) * world_size
|
118 |
+
|
119 |
+
labels = input_ids.clone()
|
120 |
+
if args.add_box:
|
121 |
+
labels[input_ids == visual_token_id] = -100
|
122 |
+
labels[input_ids == box_token_id] = -100
|
123 |
+
labels[input_ids == endofattr_token_id] = -100
|
124 |
+
if args.use_format_v2:
|
125 |
+
labels[input_ids == previsual_token_id] = -100
|
126 |
+
labels[input_ids == prebox_token_id] = -100
|
127 |
+
labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
|
128 |
+
labels[torch.roll(input_ids == box_token_id, 1)] = -100
|
129 |
+
labels[:, 0] = -100
|
130 |
+
labels[input_ids == tokenizer.pad_token_id] = -100
|
131 |
+
labels[input_ids == media_token_id] = -100
|
132 |
+
labels[input_ids == endofmedia_token_id] = -100
|
133 |
+
labels.to(device_id)
|
134 |
+
current_laion_num = input_ids.shape[0]
|
135 |
+
|
136 |
+
#### PILE FORWARD PASS ####
|
137 |
+
if batch_pile is not None and batch_pile[0] is not None and batch_pile[1] is not None:
|
138 |
+
input_ids2 = batch_pile[0].to(device_id, non_blocking=True).long()
|
139 |
+
attention_mask2 = batch_pile[1].to(device_id, dtype=cast_dtype, non_blocking=True)
|
140 |
+
input_length = input_ids.shape[-1]
|
141 |
+
|
142 |
+
input_ids2 = torch.cat([input_ids2, torch.ones((input_ids2.shape[0], input_length - input_ids2.shape[1]), device=input_ids2.device, dtype=input_ids2.dtype) * tokenizer.pad_token_id], dim=-1)
|
143 |
+
attention_mask2 = torch.cat([attention_mask2, torch.zeros((attention_mask2.shape[0], input_length - attention_mask2.shape[1]), device=attention_mask2.device, dtype=attention_mask2.dtype)], dim=-1)
|
144 |
+
|
145 |
+
labels2 = input_ids2.clone()
|
146 |
+
labels2[labels2 == tokenizer.pad_token_id] = -100
|
147 |
+
labels2[:, 0] = -100
|
148 |
+
labels2.to(device_id)
|
149 |
+
|
150 |
+
if (num_steps != 0 and num_steps % args.pile_freq == 0) or args.pile_freq == 1:
|
151 |
+
image_nums = image_nums + [0] * len(input_ids2)
|
152 |
+
image_start_index_list = image_start_index_list + [[]] * len(input_ids2)
|
153 |
+
input_ids = torch.cat([input_ids, input_ids2], dim=0)
|
154 |
+
attention_mask = torch.cat([attention_mask, attention_mask2], dim=0)
|
155 |
+
labels = torch.cat([labels, labels2], dim=0)
|
156 |
+
total_pile_token += int(attention_mask2.sum().long()) * world_size
|
157 |
+
else:
|
158 |
+
del input_ids2
|
159 |
+
del attention_mask2
|
160 |
+
del labels2
|
161 |
+
|
162 |
+
if args.instruct:
|
163 |
+
answer_token_id = tokenizer(" Answer").input_ids[0]
|
164 |
+
answer_token_loc = (input_ids == answer_token_id).nonzero()
|
165 |
+
for batch_idx, idx in answer_token_loc:
|
166 |
+
labels[batch_idx][:idx+2] = -100
|
167 |
+
|
168 |
+
if args.relation and not args.instruct:
|
169 |
+
relations = batch_laion[6]
|
170 |
+
else:
|
171 |
+
relations = None
|
172 |
+
if len(added_bbox_list) == 0:
|
173 |
+
added_bbox_list = None
|
174 |
+
update_flag = (num_steps != 0 and num_steps % args.gradient_accumulation_steps == 0) or args.gradient_accumulation_steps == 1
|
175 |
+
# do_sync = get_sync(model, update_flag)
|
176 |
+
with autocast():
|
177 |
+
# modify:
|
178 |
+
# /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/codegen/modeling_codegen.py
|
179 |
+
# /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py
|
180 |
+
# CrossEntropyLoss(reduction="none")
|
181 |
+
outputs = model(
|
182 |
+
vision_x=images,
|
183 |
+
lang_x=input_ids,
|
184 |
+
attention_mask=attention_mask,
|
185 |
+
labels=labels,
|
186 |
+
image_nums=image_nums,
|
187 |
+
image_start_index_list=image_start_index_list,
|
188 |
+
added_bbox_list=added_bbox_list,
|
189 |
+
add_box=args.add_box,
|
190 |
+
relations=relations,
|
191 |
+
)
|
192 |
+
loss_total = outputs.loss.reshape(labels.shape[0], -1)
|
193 |
+
loss_sample = loss_total.sum(-1) / (loss_total != 0).sum(-1)
|
194 |
+
loss_sample_for_laion = loss_sample[:current_laion_num]
|
195 |
+
nan_mask = torch.isnan(loss_sample_for_laion)
|
196 |
+
if nan_mask.sum() > 0:
|
197 |
+
logging.warning(f"caption NaN: {nan_mask}")
|
198 |
+
if nan_mask.sum() == len(loss_sample_for_laion) or not model.valid:
|
199 |
+
logging.info("WARNING: skip this caption loss due to some error")
|
200 |
+
loss_laion = torch.tensor(0.0).cuda()
|
201 |
+
else:
|
202 |
+
loss_laion = loss_sample_for_laion[~nan_mask].mean()
|
203 |
+
loss_caption = loss_laion
|
204 |
+
divided_loss_laion = loss_laion / args.gradient_accumulation_steps
|
205 |
+
if current_laion_num != loss_sample.shape[0]:
|
206 |
+
loss_pile = loss_sample[current_laion_num:].mean()
|
207 |
+
else:
|
208 |
+
loss_pile = torch.tensor(0.0).cuda()
|
209 |
+
divided_loss_pile = loss_pile / args.gradient_accumulation_steps
|
210 |
+
|
211 |
+
if "detection_losses" in outputs:
|
212 |
+
loss_det = outputs["detection_losses"]["loss"]
|
213 |
+
loss_iou = outputs["detection_losses"]["loss_iou"]
|
214 |
+
loss_obj = outputs["detection_losses"]["loss_obj"]
|
215 |
+
loss_cls = outputs["detection_losses"]["loss_cls"]
|
216 |
+
else:
|
217 |
+
loss_det = torch.tensor(0.0).cuda()
|
218 |
+
loss_iou = torch.tensor(0.0).cuda()
|
219 |
+
loss_obj = torch.tensor(0.0).cuda()
|
220 |
+
loss_cls = torch.tensor(0.0).cuda()
|
221 |
+
|
222 |
+
if "loss_dict" in outputs:
|
223 |
+
visual_loss_iou = outputs["loss_dict"][0]["loss_iou"]
|
224 |
+
previsual_loss_iou = outputs["loss_dict"][1]["loss_iou"]
|
225 |
+
visual_loss_obj = outputs["loss_dict"][0]["loss_obj"]
|
226 |
+
previsual_loss_obj = outputs["loss_dict"][1]["loss_obj"]
|
227 |
+
else:
|
228 |
+
visual_loss_iou = torch.tensor(0.0).cuda()
|
229 |
+
previsual_loss_iou = torch.tensor(0.0).cuda()
|
230 |
+
visual_loss_obj = torch.tensor(0.0).cuda()
|
231 |
+
previsual_loss_obj = torch.tensor(0.0).cuda()
|
232 |
+
|
233 |
+
divided_loss_det = loss_det / args.gradient_accumulation_steps
|
234 |
+
loss_rel = outputs.get("rel_loss", torch.tensor(0.0).cuda())
|
235 |
+
divided_loss_rel = loss_rel / args.gradient_accumulation_steps
|
236 |
+
loss = (
|
237 |
+
divided_loss_laion * args.loss_multiplier_laion +
|
238 |
+
divided_loss_pile * args.loss_multiplier_pile +
|
239 |
+
divided_loss_det * args.loss_multiplier_det +
|
240 |
+
divided_loss_rel * args.loss_multiplier_rel
|
241 |
+
)
|
242 |
+
|
243 |
+
scaler.scale(loss).backward()
|
244 |
+
|
245 |
+
# for logging only
|
246 |
+
loss = (
|
247 |
+
loss_laion * args.loss_multiplier_laion
|
248 |
+
+ loss_pile * args.loss_multiplier_pile
|
249 |
+
+ loss_det * args.loss_multiplier_det
|
250 |
+
+ loss_rel * args.loss_multiplier_rel
|
251 |
+
).detach()
|
252 |
+
|
253 |
+
# step optimizer and log
|
254 |
+
if update_flag:
|
255 |
+
#### MASK GRADIENTS FOR EMBEDDINGS ####
|
256 |
+
# Note (anas): Do not apply weight decay to embeddings as it will break this function.
|
257 |
+
# ! not an important point
|
258 |
+
# if args.ddp:
|
259 |
+
# def mask_embedding(m):
|
260 |
+
# if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
|
261 |
+
# zero_mask = torch.zeros_like(m.weight.grad)
|
262 |
+
# zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
|
263 |
+
# zero_mask[endofmedia_token_id] = torch.ones_like(zero_mask[endofmedia_token_id])
|
264 |
+
# m.weight.grad = m.weight.grad * zero_mask
|
265 |
+
# model.apply(mask_embedding)
|
266 |
+
total_step += 1
|
267 |
+
scaler.unscale_(optimizer)
|
268 |
+
if args.ddp:
|
269 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
270 |
+
else:
|
271 |
+
model.clip_grad_norm_(1.0)
|
272 |
+
scaler.step(optimizer)
|
273 |
+
scaler.update()
|
274 |
+
lr_scheduler.step()
|
275 |
+
optimizer.zero_grad()
|
276 |
+
# https://github.com/facebookresearch/fairscale/issues/627
|
277 |
+
model.zero_grad(set_to_none=True)
|
278 |
+
|
279 |
+
if args.rank == 0 and total_step % args.logging_steps == 0 and total_step != last_logging_step:
|
280 |
+
last_logging_step = total_step
|
281 |
+
global_step = total_step
|
282 |
+
lr = optimizer.param_groups[0]["lr"]
|
283 |
+
writer.add_scalar("lr", lr, global_step)
|
284 |
+
writer.add_scalar("scale", scaler.get_scale(), global_step)
|
285 |
+
writer.add_scalar("loss_groundcaption", loss_laion.item(), global_step)
|
286 |
+
writer.add_scalar("loss_laion", loss_caption.item(), global_step)
|
287 |
+
writer.add_scalar("loss_pile", loss_pile.item(), global_step)
|
288 |
+
writer.add_scalar("loss", loss.item(), global_step)
|
289 |
+
writer.add_scalar("loss_det", loss_det.item(), global_step)
|
290 |
+
writer.add_scalar("loss_iou", loss_iou.item(), global_step)
|
291 |
+
writer.add_scalar("loss_obj", loss_obj.item(), global_step)
|
292 |
+
writer.add_scalar("loss_cls", loss_cls.item(), global_step)
|
293 |
+
if loss_rel.item() != 0:
|
294 |
+
writer.add_scalar("loss_rel", loss_rel.item(), global_step)
|
295 |
+
if args.use_format_v2:
|
296 |
+
writer.add_scalar("loss_iou_visual", visual_loss_iou.item(), global_step)
|
297 |
+
writer.add_scalar("loss_obj_visual", visual_loss_obj.item(), global_step)
|
298 |
+
writer.add_scalar("loss_iou_previsual", previsual_loss_iou.item(), global_step)
|
299 |
+
writer.add_scalar("loss_obj_previsual", previsual_loss_obj.item(), global_step)
|
300 |
+
|
301 |
+
global_sample_num = total_laion_sample
|
302 |
+
writer.add_scalar("loss_groundcaption_vs_sample_num", loss_laion.item(), global_sample_num)
|
303 |
+
writer.add_scalar("loss_laion_vs_sample_num", loss_caption.item(), global_sample_num)
|
304 |
+
writer.add_scalar("loss_pile_vs_sample_num", loss_pile.item(), global_sample_num)
|
305 |
+
writer.add_scalar("loss_vs_sample_num", loss.item(), global_sample_num)
|
306 |
+
writer.add_scalar("loss_det_vs_sample_num", loss_det.item(), global_sample_num)
|
307 |
+
writer.add_scalar("loss_iou_vs_sample_num", loss_iou.item(), global_sample_num)
|
308 |
+
writer.add_scalar("loss_obj_vs_sample_num", loss_obj.item(), global_sample_num)
|
309 |
+
if loss_rel.item() != 0:
|
310 |
+
writer.add_scalar("loss_rel_vs_sample_num", loss_rel.item(), global_sample_num)
|
311 |
+
writer.add_scalar("lr_vs_sample_num", optimizer.param_groups[0]["lr"], global_sample_num)
|
312 |
+
|
313 |
+
writer.add_scalar("loss_groundcaption_vs_token", loss_laion.item(), total_laion_token)
|
314 |
+
writer.add_scalar("loss_laion_vs_token", loss_caption.item(), total_laion_token)
|
315 |
+
writer.add_scalar("loss_pile_vs_token", loss_pile.item(), total_pile_token)
|
316 |
+
writer.add_scalar("loss_det_vs_token", loss_det.item(), total_laion_token)
|
317 |
+
writer.add_scalar("loss_iou_vs_token", loss_iou.item(), total_laion_token)
|
318 |
+
writer.add_scalar("loss_obj_vs_token", loss_obj.item(), total_laion_token)
|
319 |
+
writer.add_scalar("loss_cls_vs_token", loss_cls.item(), total_laion_token)
|
320 |
+
if loss_rel.item() != 0:
|
321 |
+
writer.add_scalar("loss_rel_vs_token", loss_rel.item(), total_laion_token)
|
322 |
+
|
323 |
+
total_token = total_laion_token + total_pile_token
|
324 |
+
writer.add_scalar("sample_num", global_sample_num, global_step)
|
325 |
+
writer.add_scalar("total_laion_token", total_laion_token, global_step)
|
326 |
+
writer.add_scalar("total_pile_token", total_pile_token, global_step)
|
327 |
+
writer.add_scalar("total_token", total_token, global_step)
|
328 |
+
logging.info(
|
329 |
+
f"[{global_step}][{total_laion_sample}][{total_token}]. total: {loss.item():.3f} // laion: {loss_caption.item():.3f} // pile: {loss_pile.item():.3f} // iou: {loss_iou.item():.4f} // obj: {loss_obj.item():.4f} // previsual_obj: {previsual_loss_obj.item():.4f} // visual_obj: {visual_loss_obj.item():.4f} // previsual_iou: {previsual_loss_iou.item():.4f} // visual_iou: {visual_loss_iou.item():.4f} // lr: {lr:.2e} // scale: {scaler.get_scale()}"
|
330 |
+
)
|
331 |
+
|
332 |
+
if total_step % args.save_interval == 0 and total_step != last_save_step:
|
333 |
+
last_save_step = total_step
|
334 |
+
torch.distributed.barrier()
|
335 |
+
if args.ddp:
|
336 |
+
cpu_state = model.state_dict()
|
337 |
+
# if args.rank == 0:
|
338 |
+
# optimizer_state = optimizer.state_dict()
|
339 |
+
else:
|
340 |
+
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
341 |
+
with FSDP.state_dict_type(
|
342 |
+
model, StateDictType.FULL_STATE_DICT, save_policy
|
343 |
+
):
|
344 |
+
cpu_state = model.state_dict()
|
345 |
+
torch.distributed.barrier()
|
346 |
+
# https://pytorch.org/docs/1.12/fsdp.html
|
347 |
+
# need to pass optim_groups as optim_input
|
348 |
+
# optimizer_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_groups)
|
349 |
+
if args.rank == 0:
|
350 |
+
checkpoint_dict = {
|
351 |
+
"model_state_dict": cpu_state,
|
352 |
+
# "optimizer_state_dict": optimizer_state,
|
353 |
+
"lr_scheduler_state_dict": lr_scheduler.state_dict(),
|
354 |
+
"scaler_state_dict": scaler.state_dict(),
|
355 |
+
"total_pile_token": total_pile_token,
|
356 |
+
"total_laion_token": total_laion_token,
|
357 |
+
"total_laion_sample": total_laion_sample,
|
358 |
+
"total_step": total_step,
|
359 |
+
}
|
360 |
+
logging.info(f"Saving checkpoint to {args.run_name}/checkpoint_{total_step}.pt")
|
361 |
+
torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{total_step}.pt")
|
362 |
+
del checkpoint_dict
|
363 |
+
if args.delete_previous_checkpoint and total_step-args.save_interval > 0 and (total_step-args.save_interval) % args.skip_delete_pattern != 0:
|
364 |
+
try:
|
365 |
+
os.remove(f"{args.run_name}/checkpoint_{total_step-args.save_interval}.pt")
|
366 |
+
except:
|
367 |
+
pass
|
368 |
+
torch.distributed.barrier()
|
369 |
+
|
370 |
+
|
371 |
+
class AverageMeter(object):
|
372 |
+
"""Computes and stores the average and current value"""
|
373 |
+
|
374 |
+
def __init__(self):
|
375 |
+
self.reset()
|
376 |
+
|
377 |
+
def reset(self):
|
378 |
+
self.val = 0
|
379 |
+
self.avg = 0
|
380 |
+
self.sum = 0
|
381 |
+
self.count = 0
|
382 |
+
|
383 |
+
def update(self, val, n=1):
|
384 |
+
self.val = val
|
385 |
+
self.sum += val * n
|
386 |
+
self.count += n
|
387 |
+
self.avg = self.sum / self.count
|
multimodal/open_flamingo.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: open-flamingo
|
3 |
+
Version: 0.0.2
|
4 |
+
Summary: An open-source framework for training large multimodal models
|
5 |
+
License: MIT
|
6 |
+
Keywords: machine learning
|
7 |
+
Classifier: Development Status :: 4 - Beta
|
8 |
+
Classifier: Intended Audience :: Developers
|
9 |
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
10 |
+
Classifier: License :: OSI Approved :: MIT License
|
11 |
+
Classifier: Programming Language :: Python :: 3.9
|
12 |
+
Description-Content-Type: text/markdown
|
13 |
+
License-File: LICENSE
|
14 |
+
|
15 |
+
# 🦩 OpenFlamingo
|
16 |
+
|
17 |
+
[![PyPI version](https://badge.fury.io/py/open_flamingo.svg)](https://badge.fury.io/py/open_flamingo)
|
18 |
+
|
19 |
+
[Blog post](https://laion.ai/blog/open-flamingo/) | Paper (coming soon)
|
20 |
+
|
21 |
+
Welcome to our open source version of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) model! In this repository, we provide a PyTorch implementation for training and evaluating OpenFlamingo models. We also provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) trained on a new Multimodal C4 dataset (coming soon). Please refer to our blog post for more details.
|
22 |
+
|
23 |
+
This repo is still under development, and we hope to release better performing and larger OpenFlamingo models soon. If you have any questions, please feel free to open an issue. We also welcome contributions!
|
24 |
+
|
25 |
+
# Table of Contents
|
26 |
+
- [Installation](#installation)
|
27 |
+
- [Approach](#approach)
|
28 |
+
* [Model architecture](#model-architecture)
|
29 |
+
- [Usage](#usage)
|
30 |
+
* [Initializing an OpenFlamingo model](#initializing-an-openflamingo-model)
|
31 |
+
* [Generating text](#generating-text)
|
32 |
+
- [Training](#training)
|
33 |
+
* [Dataset](#dataset)
|
34 |
+
- [Evaluation](#evaluation)
|
35 |
+
- [Future plans](#future-plans)
|
36 |
+
- [Team](#team)
|
37 |
+
- [Acknowledgments](#acknowledgments)
|
38 |
+
- [Citing](#citing)
|
39 |
+
|
40 |
+
# Installation
|
41 |
+
|
42 |
+
To install the package in an existing environment, run
|
43 |
+
```
|
44 |
+
pip install open-flamingo
|
45 |
+
```
|
46 |
+
|
47 |
+
or to create a conda environment for running OpenFlamingo, run
|
48 |
+
```
|
49 |
+
conda env create -f environment.yml
|
50 |
+
```
|
51 |
+
|
52 |
+
# Usage
|
53 |
+
We provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) using a CLIP ViT-Large vision encoder and a LLaMA-7B language model. In general, we support any [CLIP vision encoder](https://huggingface.co/models?search=clip). For the language model, we support [LLaMA](https://huggingface.co/models?search=llama), [OPT](https://huggingface.co/models?search=opt), [GPT-Neo](https://huggingface.co/models?search=gpt-neo), [GPT-J](https://huggingface.co/models?search=gptj), and [Pythia](https://huggingface.co/models?search=pythia) models.
|
54 |
+
|
55 |
+
#### NOTE: To use LLaMA models, you will need to install the latest version of transformers via
|
56 |
+
```
|
57 |
+
pip install git+https://github.com/huggingface/transformers
|
58 |
+
```
|
59 |
+
Use this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) for converting LLaMA weights to HuggingFace format.
|
60 |
+
|
61 |
+
## Initializing an OpenFlamingo model
|
62 |
+
``` python
|
63 |
+
from open_flamingo import create_model_and_transforms
|
64 |
+
|
65 |
+
model, image_processor, tokenizer = create_model_and_transforms(
|
66 |
+
clip_vision_encoder_path="ViT-L-14",
|
67 |
+
clip_vision_encoder_pretrained="openai",
|
68 |
+
lang_encoder_path="<path to llama weights in HuggingFace format>",
|
69 |
+
tokenizer_path="<path to llama tokenizer in HuggingFace format>",
|
70 |
+
cross_attn_every_n_layers=4
|
71 |
+
)
|
72 |
+
|
73 |
+
# grab model checkpoint from huggingface hub
|
74 |
+
from huggingface_hub import hf_hub_download
|
75 |
+
import torch
|
76 |
+
|
77 |
+
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B", "checkpoint.pt")
|
78 |
+
model.load_state_dict(torch.load(checkpoint_path), strict=False)
|
79 |
+
```
|
80 |
+
|
81 |
+
## Generating text
|
82 |
+
Here is an example of generating text conditioned on interleaved images/text, in this case we will do few-shot image captioning.
|
83 |
+
|
84 |
+
``` python
|
85 |
+
from PIL import Image
|
86 |
+
import requests
|
87 |
+
|
88 |
+
"""
|
89 |
+
Step 1: Load images
|
90 |
+
"""
|
91 |
+
demo_image_one = Image.open(
|
92 |
+
requests.get(
|
93 |
+
"http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
|
94 |
+
).raw
|
95 |
+
)
|
96 |
+
|
97 |
+
demo_image_two = Image.open(
|
98 |
+
requests.get(
|
99 |
+
"http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
|
100 |
+
stream=True
|
101 |
+
).raw
|
102 |
+
)
|
103 |
+
|
104 |
+
query_image = Image.open(
|
105 |
+
requests.get(
|
106 |
+
"http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
|
107 |
+
stream=True
|
108 |
+
).raw
|
109 |
+
)
|
110 |
+
|
111 |
+
|
112 |
+
"""
|
113 |
+
Step 2: Preprocessing images
|
114 |
+
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape
|
115 |
+
batch_size x num_media x num_frames x channels x height x width.
|
116 |
+
In this case batch_size = 1, num_media = 3, num_frames = 1
|
117 |
+
(this will always be one expect for video which we don't support yet),
|
118 |
+
channels = 3, height = 224, width = 224.
|
119 |
+
"""
|
120 |
+
vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
|
121 |
+
vision_x = torch.cat(vision_x, dim=0)
|
122 |
+
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
|
123 |
+
|
124 |
+
"""
|
125 |
+
Step 3: Preprocessing text
|
126 |
+
Details: In the text we expect an <|#image#|> special token to indicate where an image is.
|
127 |
+
We also expect an <|endofchunk|> special token to indicate the end of the text
|
128 |
+
portion associated with an image.
|
129 |
+
"""
|
130 |
+
tokenizer.padding_side = "left" # For generation padding tokens should be on the left
|
131 |
+
lang_x = tokenizer(
|
132 |
+
["<|#image#|>An image of two cats.<|endofchunk|><|#image#|>An image of a bathroom sink.<|endofchunk|><|#image#|>An image of"],
|
133 |
+
return_tensors="pt",
|
134 |
+
)
|
135 |
+
|
136 |
+
|
137 |
+
"""
|
138 |
+
Step 4: Generate text
|
139 |
+
"""
|
140 |
+
generated_text = model.generate(
|
141 |
+
vision_x=vision_x,
|
142 |
+
lang_x=lang_x["input_ids"],
|
143 |
+
attention_mask=lang_x["attention_mask"],
|
144 |
+
max_new_tokens=20,
|
145 |
+
num_beams=3,
|
146 |
+
)
|
147 |
+
|
148 |
+
print("Generated text: ", tokenizer.decode(generated_text[0]))
|
149 |
+
```
|
150 |
+
|
151 |
+
# Approach
|
152 |
+
OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. Multimodal C4) and can be used to generate text conditioned on interleaved images/text. For example, OpenFlamingo can be used to generate a caption for an image, or to generate a question given an image and a text passage. The benefit of this approach is that we are able to rapidly adapt to new tasks using in-context training.
|
153 |
+
|
154 |
+
## Model architecture
|
155 |
+
OpenFlamingo seeks to fuse a pretrained vision encoder and a language model using cross attention layers. The model architecture is shown below.
|
156 |
+
|
157 |
+
![OpenFlamingo architecture](docs/flamingo.png)
|
158 |
+
Credit: [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)
|
159 |
+
|
160 |
+
# Training
|
161 |
+
To train a model, modify the following example command, which uses OPT 1.3B as an example LM:
|
162 |
+
```
|
163 |
+
torchrun --nnodes=1 --nproc_per_node=4 train.py \
|
164 |
+
--run_name flamingo3B \
|
165 |
+
--lm_path facebook/opt-1.3b \
|
166 |
+
--tokenizer_path facebook/opt-1.3b \
|
167 |
+
--dataset_resampled \
|
168 |
+
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
|
169 |
+
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
|
170 |
+
--batch_size_mmc4 4 \
|
171 |
+
--batch_size_laion 8 \
|
172 |
+
--train_num_samples_mmc4 125000 \
|
173 |
+
--train_num_samples_laion 250000 \
|
174 |
+
--loss_multiplier_laion 0.2 \
|
175 |
+
--workers=6 \
|
176 |
+
--num_epochs 250 \
|
177 |
+
--lr_scheduler constant \
|
178 |
+
--warmup_steps 5000 \
|
179 |
+
--use_media_placement_augmentation \
|
180 |
+
--mmc4_textsim_threshold 30
|
181 |
+
```
|
182 |
+
|
183 |
+
## Dataset
|
184 |
+
We expect all our training datasets to be [WebDataset](https://github.com/webdataset/webdataset) shards.
|
185 |
+
We train our models on the [LAION 2B](https://huggingface.co/datasets/laion/laion2B-en) and Multimodal C4 (coming soon) datasets. By default the LAION 2B dataset is in WebDataset format if it is downloaded using the [img2dataset tool](https://github.com/rom1504/img2dataset) and Multimodal C4 comes packaged in the WebDataset format.
|
186 |
+
|
187 |
+
|
188 |
+
# Evaluation
|
189 |
+
We currently support running evaluations on [COCO](https://cocodataset.org/#home), [VQAv2](https://visualqa.org/index.html), [OKVQA](https://okvqa.allenai.org), [Flickr30k](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset), and [ImageNet](https://image-net.org/index.php). Note that currently these evaluations are ran in validation mode (as specified in the Flamingo paper). We will be adding support for running evaluations in test mode in the future.
|
190 |
+
|
191 |
+
Before evaluating the model, you will need to install the coco evaluation package by running the following command:
|
192 |
+
```
|
193 |
+
pip install pycocoevalcap
|
194 |
+
```
|
195 |
+
|
196 |
+
To run evaluations on OKVQA you will need to run the following command:
|
197 |
+
```
|
198 |
+
import nltk
|
199 |
+
nltk.download('wordnet')
|
200 |
+
```
|
201 |
+
|
202 |
+
To evaluate the model, run the script at `open_flamingo/scripts/run_eval.sh`
|
203 |
+
|
204 |
+
# Future plans
|
205 |
+
- [ ] Add support for video input
|
206 |
+
- [ ] Release better performing and larger OpenFlamingo models
|
207 |
+
- [ ] Expand our evaluation suite
|
208 |
+
- [ ] Add support for FSDP training
|
209 |
+
|
210 |
+
# Team
|
211 |
+
|
212 |
+
OpenFlamingo is developed by:
|
213 |
+
|
214 |
+
[Anas Awadalla](https://anas-awadalla.streamlit.app/), [Irena Gao](https://i-gao.github.io/), [Joshua Gardner](https://homes.cs.washington.edu/~jpgard/), [Jack Hessel](https://jmhessel.com/), [Yusuf Hanafy](https://www.linkedin.com/in/yusufhanafy/), [Wanrong Zhu](https://wanrong-zhu.com/), [Kalyani Marathe](https://sites.google.com/uw.edu/kalyanimarathe/home?authuser=0), [Yonatan Bitton](https://yonatanbitton.github.io/), [Samir Gadre](https://sagadre.github.io/), [Jenia Jitsev](https://scholar.google.de/citations?user=p1FuAMkAAAAJ&hl=en), [Simon Kornblith](https://simonster.com/), [Pang Wei Koh](https://koh.pw/), [Gabriel Ilharco](https://gabrielilharco.com/), [Mitchell Wortsman](https://mitchellnw.github.io/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/).
|
215 |
+
|
216 |
+
The team is primarily from the University of Washington, Stanford, AI2, UCSB, and Google.
|
217 |
+
|
218 |
+
# Acknowledgments
|
219 |
+
This code is based on Lucidrains' [flamingo implementation](https://github.com/lucidrains/flamingo-pytorch) and David Hansmair's [flamingo-mini repo](https://github.com/dhansmair/flamingo-mini). Thank you for making your code public! We also thank the [OpenCLIP](https://github.com/mlfoundations/open_clip) team as we use their data loading code and take inspiration from their library design.
|
220 |
+
|
221 |
+
We would also like to thank [Jean-Baptiste Alayrac](https://www.jbalayrac.com) and [Antoine Miech](https://antoine77340.github.io) for their advice, [Rohan Taori](https://www.rohantaori.com/), [Nicholas Schiefer](https://nicholasschiefer.com/), [Deep Ganguli](https://hai.stanford.edu/people/deep-ganguli), [Thomas Liao](https://thomasliao.com/), [Tatsunori Hashimoto](https://thashim.github.io/), and [Nicholas Carlini](https://nicholas.carlini.com/) for their help with assessing the safety risks of our release, and to [Stability AI](https://stability.ai) for providing us with compute resources to train these models.
|
222 |
+
|
223 |
+
# Citing
|
224 |
+
If you found this repository useful, please consider citing:
|
225 |
+
|
226 |
+
```
|
227 |
+
@software{anas_awadalla_2023_7733589,
|
228 |
+
author = {Awadalla, Anas and Gao, Irena and Gardner, Joshua and Hessel, Jack and Hanafy, Yusuf and Zhu, Wanrong and Marathe, Kalyani and Bitton, Yonatan and Gadre, Samir and Jitsev, Jenia and Kornblith, Simon and Koh, Pang Wei and Ilharco, Gabriel and Wortsman, Mitchell and Schmidt, Ludwig},
|
229 |
+
title = {OpenFlamingo},
|
230 |
+
month = mar,
|
231 |
+
year = 2023,
|
232 |
+
publisher = {Zenodo},
|
233 |
+
version = {v0.1.1},
|
234 |
+
doi = {10.5281/zenodo.7733589},
|
235 |
+
url = {https://doi.org/10.5281/zenodo.7733589}
|
236 |
+
}
|
237 |
+
```
|
238 |
+
|
239 |
+
```
|
240 |
+
@article{Alayrac2022FlamingoAV,
|
241 |
+
title={Flamingo: a Visual Language Model for Few-Shot Learning},
|
242 |
+
author={Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan},
|
243 |
+
journal={ArXiv},
|
244 |
+
year={2022},
|
245 |
+
volume={abs/2204.14198}
|
246 |
+
}
|
247 |
+
```
|
multimodal/open_flamingo.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LICENSE
|
2 |
+
README.md
|
3 |
+
setup.py
|
4 |
+
open_flamingo/__init__.py
|
5 |
+
open_flamingo.egg-info/PKG-INFO
|
6 |
+
open_flamingo.egg-info/SOURCES.txt
|
7 |
+
open_flamingo.egg-info/dependency_links.txt
|
8 |
+
open_flamingo.egg-info/requires.txt
|
9 |
+
open_flamingo.egg-info/top_level.txt
|
10 |
+
open_flamingo/chat/__init__.py
|
11 |
+
open_flamingo/chat/conversation.py
|
12 |
+
open_flamingo/eval/__init__.py
|
13 |
+
open_flamingo/eval/classification.py
|
14 |
+
open_flamingo/eval/coco_metric.py
|
15 |
+
open_flamingo/eval/eval_datasets.py
|
16 |
+
open_flamingo/eval/evaluate.py
|
17 |
+
open_flamingo/eval/evaluate_debug.py
|
18 |
+
open_flamingo/eval/evaluate_find_showcase.py
|
19 |
+
open_flamingo/eval/evaluate_temp.py
|
20 |
+
open_flamingo/eval/imagenet_utils.py
|
21 |
+
open_flamingo/eval/ok_vqa_utils.py
|
22 |
+
open_flamingo/eval/vqa_metric.py
|
23 |
+
open_flamingo/eval/dataset_zoo/__init__.py
|
24 |
+
open_flamingo/eval/dataset_zoo/aro_datasets.py
|
25 |
+
open_flamingo/eval/dataset_zoo/constants.py
|
26 |
+
open_flamingo/eval/dataset_zoo/perturbations.py
|
27 |
+
open_flamingo/eval/dataset_zoo/retrieval.py
|
28 |
+
open_flamingo/eval/dataset_zoo/utils.py
|
29 |
+
open_flamingo/eval/task/__init__.py
|
30 |
+
open_flamingo/eval/task/caption.py
|
31 |
+
open_flamingo/eval/task/caption_chat.py
|
32 |
+
open_flamingo/eval/task/cola.py
|
33 |
+
open_flamingo/eval/task/crepe.py
|
34 |
+
open_flamingo/eval/task/gqa.py
|
35 |
+
open_flamingo/eval/task/mmbench.py
|
36 |
+
open_flamingo/eval/task/reg.py
|
37 |
+
open_flamingo/eval/task/utils.py
|
38 |
+
open_flamingo/eval/task/vl_checklist.py
|
39 |
+
open_flamingo/src/__init__.py
|
40 |
+
open_flamingo/src/attention.py
|
41 |
+
open_flamingo/src/factory.py
|
42 |
+
open_flamingo/src/flamingo.py
|
43 |
+
open_flamingo/src/flamingo_lm.py
|
44 |
+
open_flamingo/src/gcn.py
|
45 |
+
open_flamingo/src/helpers.py
|
46 |
+
open_flamingo/src/utils.py
|
47 |
+
open_flamingo/train/__init__.py
|
48 |
+
open_flamingo/train/data2.py
|
49 |
+
open_flamingo/train/distributed.py
|
50 |
+
open_flamingo/train/instruction_template.py
|
51 |
+
open_flamingo/train/train.py
|
52 |
+
open_flamingo/train/train_utils.py
|
53 |
+
tests/test_flamingo_model.py
|
multimodal/open_flamingo.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
multimodal/open_flamingo.egg-info/requires.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops
|
2 |
+
einops-exts
|
3 |
+
transformers==4.31.0
|
4 |
+
torch==1.12.1
|
5 |
+
torchvision==0.13.1
|
6 |
+
pillow==9.3.0
|
7 |
+
more-itertools
|
8 |
+
datasets==2.9.0
|
9 |
+
braceexpand==0.1.7
|
10 |
+
webdataset
|
11 |
+
wandb==0.13.10
|
12 |
+
nltk
|
13 |
+
scipy
|
14 |
+
inflection
|
15 |
+
sentencepiece
|
16 |
+
open_clip_torch==2.20.0
|
17 |
+
opencv-python==4.7.0.68
|
multimodal/open_flamingo.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
open_flamingo
|