"""Modified from https://github.com/mlfoundations/open_clip"""

from typing import Optional, Tuple

import collections
import logging
import os

import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
from torchvision import transforms as T
from PIL import Image, ImageFont
from torch.utils.data import Dataset

from fromage import utils


def collate_fn(batch):
    """Collates samples into a batch, dropping any that failed to load (None)."""
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)
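
# A minimal sketch of wiring collate_fn into a DataLoader; `dataset` and the
# batch size are hypothetical. Because None entries are dropped before
# collation, batches may come out smaller than batch_size.
#
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
#   for image_paths, images, cap_imgs, tokens, caption_lens in loader:
#       ...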


def get_dataset(args, split: str, tokenizer, precision: str = 'fp32') -> Dataset:
    assert split in ['train', 'val'], \
        f'Expected split to be one of "train" or "val", got {split} instead.'

    dataset_paths = []
    image_data_dirs = []
    train = split == 'train'

    # Resolve the .tsv annotation file and image directory for the requested split.
    if split == 'train':
        if 'cc3m' in args.dataset:
            dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_train.tsv'))
            image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/training/'))
        else:
            raise NotImplementedError
    elif split == 'val':
        if 'cc3m' in args.val_dataset:
            dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_val.tsv'))
            image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/validation/'))
        else:
            raise NotImplementedError

        assert len(dataset_paths) == len(image_data_dirs) == 1, (dataset_paths, image_data_dirs)
    else:
        raise NotImplementedError

    if len(dataset_paths) > 1:
        print(f'{len(dataset_paths)} datasets requested: {dataset_paths}')
        dataset = torch.utils.data.ConcatDataset([
            CsvDataset(path, image_dir, tokenizer, 'image', 'caption',
                       args.visual_model, train=train, max_len=args.max_len,
                       precision=args.precision, image_size=args.image_size,
                       retrieval_token_idx=args.retrieval_token_idx)
            for (path, image_dir) in zip(dataset_paths, image_data_dirs)])
    elif len(dataset_paths) == 1:
        dataset = CsvDataset(dataset_paths[0], image_data_dirs[0], tokenizer, 'image', 'caption',
                             args.visual_model, train=train, max_len=args.max_len,
                             precision=args.precision, image_size=args.image_size,
                             retrieval_token_idx=args.retrieval_token_idx)
    else:
        raise ValueError(f'There should be at least one valid dataset, got train={args.dataset}, val={args.val_dataset} instead.')
    return dataset
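
# A minimal usage sketch (not from the original source). The args namespace is
# hypothetical; its field names mirror the attributes read above, and
# `tokenizer` is assumed to be a HuggingFace tokenizer with a [RET] token
# added, whose vocabulary index is `ret_idx`.
#
#   import argparse
#   args = argparse.Namespace(
#       dataset=['cc3m'], val_dataset=['cc3m'],
#       dataset_dir='datasets/', image_dir='data/',
#       visual_model='openai/clip-vit-large-patch14',
#       max_len=32, precision='fp32', image_size=224,
#       retrieval_token_idx=ret_idx)
#   train_dataset = get_dataset(args, 'train', tokenizer)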


class CsvDataset(Dataset):
    """Dataset of (image, caption) pairs listed in a tab-separated file."""

    def __init__(self, input_filename, base_image_dir, tokenizer, img_key,
                 caption_key, feature_extractor_model: str,
                 train: bool = True, max_len: int = 32, sep="\t", precision: str = 'fp32',
                 image_size: int = 224, retrieval_token_idx: int = -1):
        logging.debug(f'Loading tsv data from {input_filename}.')
        df = pd.read_csv(input_filename, sep=sep)

        self.base_image_dir = base_image_dir
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        assert len(self.images) == len(self.captions)

        self.feature_extractor_model = feature_extractor_model
        self.feature_extractor = utils.get_feature_extractor_for_model(
            feature_extractor_model, image_size=image_size, train=False)
        self.image_size = image_size

        self.tokenizer = tokenizer
        self.max_len = max_len
        self.precision = precision
        self.retrieval_token_idx = retrieval_token_idx

        # Lazily initialized in __getitem__, used for rendering captions as images.
        self.font = None

        logging.debug('Done loading data.')
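
    # A sketch of the expected .tsv layout, assuming the 'image' and 'caption'
    # column keys that get_dataset passes in; file names are hypothetical, and
    # each image path is resolved relative to base_image_dir:
    #
    #   image<TAB>caption
    #   0001.jpg<TAB>a dog runs across a field
    #   0002.jpg<TAB>two people ride bicycles at sunset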

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        # If a sample fails to load, retry with a randomly chosen index.
        while True:
            image_path = os.path.join(self.base_image_dir, str(self.images[idx]))
            caption = str(self.captions[idx])

            try:
                img = Image.open(image_path)
                images = utils.get_pixel_values_for_model(self.feature_extractor, img)

                # Append the retrieval token, then tokenize and pad to max_len.
                caption += '[RET]'
                tokenized_data = self.tokenizer(
                    caption,
                    return_tensors="pt",
                    padding='max_length',
                    truncation=True,
                    max_length=self.max_len)
                tokens = tokenized_data.input_ids[0]

                # Number of non-padding tokens.
                caption_len = tokenized_data.attention_mask[0].sum()

                # Render the decoded caption as an image (for visualization).
                decode_caption = self.tokenizer.decode(tokens, skip_special_tokens=False)
                self.font = self.font or ImageFont.load_default()
                cap_img = utils.create_image_of_text(
                    decode_caption.encode('ascii', 'ignore'), width=self.image_size,
                    nrows=2, font=self.font)

                # If truncation dropped [RET], force the last token to be the
                # retrieval token so every caption ends with it.
                if tokens[-1] not in [self.retrieval_token_idx, self.tokenizer.pad_token_id]:
                    tokens[-1] = self.retrieval_token_idx

                return image_path, images, cap_img, tokens, caption_len
            except Exception as e:
                print(f'Error reading {image_path} with caption {caption}: {e}')
                # np.random.randint has an exclusive upper bound, so use
                # len(self) to keep the last sample reachable.
                idx = np.random.randint(0, len(self))
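
# A minimal sketch of using CsvDataset directly; the paths, the tokenizer, and
# `ret_idx` (the [RET] token id) are hypothetical.
#
#   dataset = CsvDataset('cc3m_train.tsv', 'cc3m/training/', tokenizer,
#                        'image', 'caption', 'openai/clip-vit-large-patch14',
#                        retrieval_token_idx=ret_idx)
#   image_path, pixel_values, cap_img, tokens, caption_len = dataset[0]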