# Spaces: Runtime error
import base64
import io
import random
from typing import Any, Dict

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

from open_flamingo.eval.task.utils import get_object_from_text
def decode_base64_to_image(base64_string): | |
image_data = base64.b64decode(base64_string) | |
image = Image.open(io.BytesIO(image_data)) | |
return image | |
class MMBenchDataset(Dataset):
    """Map-style dataset over an MMBench-format tab-separated file.

    The columns ``index``, ``image`` (base64-encoded), ``question``,
    ``category`` and ``l2-category`` are read unconditionally, so a missing one
    raises ``KeyError``. The option columns ``A``-``E``, ``answer`` and
    ``hint`` are optional: an absent column or an empty (NaN) cell yields None.

    Args:
        data_file: path or file-like object passed to ``pandas.read_csv``.
        sys_prompt: header line placed before the enumerated options.
    """

    # Columns that may hold answer options, in the order they are listed.
    _OPTION_KEYS = ('A', 'B', 'C', 'D', 'E')

    def __init__(self,
                 data_file,
                 sys_prompt: str = 'There are several options:'):
        self.df = pd.read_csv(data_file, sep='\t')
        self.sys_prompt = sys_prompt

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the sample at position ``idx`` as a dict.

        Keys: ``img`` (decoded image), ``question``, ``answer`` (None if
        absent/empty), ``options`` (prompt text listing the options),
        ``category``, ``l2-category``, ``options_dict`` (option letter ->
        text, only non-empty options), ``index`` and ``context`` (the
        ``hint`` column, None if absent/empty).
        """
        # Every `df.iloc[idx]` builds a fresh Series; fetch the row once.
        row = self.df.iloc[idx]

        options = {}
        for key in self._OPTION_KEYS:
            value = self._value_or_none(row, key)
            if value is not None:
                options[key] = value

        options_prompt = f'{self.sys_prompt}\n' + ''.join(
            f'{key}. {value}\n' for key, value in options.items()
        )

        return {
            'img': decode_base64_to_image(row['image']),
            'question': row['question'],
            # Same missing-value rule as the other optional fields: an absent
            # column or an empty cell becomes None (previously NaN leaked out).
            'answer': self._value_or_none(row, 'answer'),
            'options': options_prompt,
            'category': row['category'],
            'l2-category': row['l2-category'],
            'options_dict': options,
            'index': row['index'],
            'context': self._value_or_none(row, 'hint'),
        }

    def load_from_df(self, idx: int, key: str) -> Any:
        """Return the cell at row position ``idx`` / column ``key``.

        Returns None when the column does not exist or the cell is missing
        (NaN/None according to ``pd.isna``).
        """
        return self._value_or_none(self.df.iloc[idx], key)

    @staticmethod
    def _value_or_none(row: pd.Series, key: str) -> Any:
        """Return ``row[key]``, or None if ``key`` is absent or the value is missing."""
        # `in` on a Series tests its index labels, i.e. the column names here.
        if key in row and not pd.isna(row[key]):
            return row[key]
        return None
def evaluate_mmbench( | |
model, | |
tokenizer, | |
image_processor, | |
batch_size=1, | |
image_dir_path=None, | |
questions_json_path=None, | |
annotations_json_path=None, | |
vis_embed_size=None, | |
rank=0, | |
world_size=1, | |
id=0, | |
): | |
dataset_name = "mmbench" | |
dataset = MMBenchDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/mmbench/mmbench_dev_20230712.tsv") | |
for sample in dataset: | |
print(sample) | |
if __name__ == '__main__':
    # Smoke run without a model: evaluate_mmbench only loads and prints the
    # dataset, so None placeholders suffice for the model-related arguments.
    evaluate_mmbench(model=None, tokenizer=None, image_processor=None)