Spaces:
Runtime error
Runtime error
File size: 3,298 Bytes
e770d90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import json
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from open_flamingo.eval.imagenet_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
class COCOFlickrDataset(Dataset):
    """Image-caption dataset read from an annotations JSON file.

    Every entry of the file's top-level ``"annotations"`` list is one sample;
    the image file for a sample is located from the entry's ``"image_id"``.

    Args:
        image_dir_path: Directory that holds the ``.jpg`` images.
        annotations_path: Path to a JSON file with a top-level
            ``"annotations"`` list. Each entry is read for ``"image_id"``
            and ``"caption"``.
        is_flickr: If true, image files are named ``<image_id>.jpg``.
            Otherwise ``image_id`` is zero-padded to 12 digits, which
            requires it to be an integer.
    """

    def __init__(
        self,
        image_dir_path,
        annotations_path,
        is_flickr=False,
    ):
        self.image_dir_path = image_dir_path
        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(annotations_path) as annotations_file:
            self.annotations = json.load(annotations_file)["annotations"]
        self.is_flickr = is_flickr

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotations)

    def get_img_path(self, idx):
        """Return the image path for the annotation at position ``idx``."""
        image_id = self.annotations[idx]["image_id"]
        if self.is_flickr:
            return f"{self.image_dir_path}/{image_id}.jpg"
        # Non-Flickr file names use the id zero-padded to 12 digits.
        return f"{self.image_dir_path}/{image_id:012d}.jpg"

    def __getitem__(self, idx):
        """Return a dict with the ``image``, ``caption`` and ``image_id``.

        The image is returned as opened by ``PIL.Image.open``; no conversion
        or transform is applied here.
        """
        annotation = self.annotations[idx]
        image = Image.open(self.get_img_path(idx))
        return {
            "image": image,
            "caption": annotation["caption"],
            "image_id": annotation["image_id"],
        }
class VQADataset(Dataset):
    """Visual question answering dataset built from two JSON files.

    Questions come from the ``"questions"`` list of ``question_path`` and
    answers from the ``"annotations"`` list of ``annotations_path``. Sample
    ``idx`` pairs ``questions[idx]`` with ``annotations[idx]``.

    Args:
        image_dir_path: Directory that holds the images.
        question_path: Path to the questions JSON file.
        annotations_path: Path to the annotations JSON file.
        vqa_dataset: Dataset flavour, either ``"vqa"`` or ``"ok_vqa"``.
            Any other value makes ``get_img_path`` raise.
    """

    def __init__(
        self,
        image_dir_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/train2014/",
        question_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_OpenEnded_mscoco_train2014_questions.json",
        annotations_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_mscoco_train2014_annotations.json",
        vqa_dataset="vqa",
    ):
        # Context managers close both files deterministically instead of
        # leaving the handles for the garbage collector.
        with open(question_path, "r") as question_file:
            self.questions = json.load(question_file)["questions"]
        with open(annotations_path, "r") as annotations_file:
            self.answers = json.load(annotations_file)["annotations"]
        self.image_dir_path = image_dir_path
        self.vqa_dataset = vqa_dataset

    def __len__(self):
        """Return the number of questions."""
        return len(self.questions)

    def get_img_path(self, question):
        """Return the image path for a question entry.

        Args:
            question: Mapping with an integer ``"image_id"``.

        Raises:
            Exception: If ``self.vqa_dataset`` is neither ``"vqa"`` nor
                ``"ok_vqa"``.
        """
        if self.vqa_dataset not in ("vqa", "ok_vqa"):
            raise Exception(f"Unknown VQA dataset {self.vqa_dataset}")
        # Both supported flavours use the same file naming scheme.
        # NOTE(review): the default paths point at train2014 data while the
        # file name prefix is COCO_val2014 -- confirm this is intended.
        return os.path.join(
            self.image_dir_path, f"COCO_val2014_{question['image_id']:012d}.jpg"
        )

    def __getitem__(self, idx):
        """Return a dict with ``image``, ``question``, ``answers``, ``question_id``.

        ``answers`` is the list of ``"answer"`` strings of the annotation at
        the same index as the question.
        """
        question = self.questions[idx]
        # Assumes both JSON lists are in the same order -- TODO confirm; no
        # question_id matching is done here.
        answers = self.answers[idx]
        img_path = self.get_img_path(question)
        image = Image.open(img_path)
        return {
            "image": image,
            "question": question["question"],
            "answers": [a["answer"] for a in answers["answers"]],
            "question_id": question["question_id"],
        }
class ImageNetDataset(ImageFolder):
    """ImageNet1k dataset whose items are dicts rather than tuples."""

    def __init__(self, root, **kwargs):
        """Forward ``root`` and any extra keyword arguments to ``ImageFolder``."""
        super().__init__(root=root, **kwargs)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` with its class id and class name."""
        image, class_index = super().__getitem__(idx)
        return dict(
            image=image,
            # Numeric ID of the ImageNet class.
            class_id=class_index,
            # Human-readable name of the ImageNet class.
            class_name=IMAGENET_1K_CLASS_ID_TO_LABEL[class_index],
        )
if __name__ == "__main__":
    # Manual smoke test: print every sample of the VQA dataset loaded from its
    # default paths. The previous code instantiated ``GQADataset``, a name this
    # module neither defines nor imports, so the script failed with NameError.
    vqa_dataset = VQADataset()
    for sample in vqa_dataset:
        print(sample)
|