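# Evaluation data utilities: a JSONL text dataset, an LMDB image dataset, and
# DataLoader builders used for retrieval and zero-shot evaluation.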
import os
import logging
import json
from dataclasses import dataclass
from pathlib import Path
from PIL import Image
import base64
from io import BytesIO
import torch
import lmdb
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
import torchvision.datasets as datasets
from clip import tokenize
def _convert_to_rgb(image):
return image.convert('RGB')
def _preprocess_text(text):
# adapt the text to Chinese BERT vocab
text = text.lower().replace("“", "\"").replace("”", "\"")
return text
class EvalTxtDataset(Dataset):
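    """Evaluation text dataset: yields (text_id, tokenized_text) pairs read from a JSONL file."""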
def __init__(self, jsonl_filename, max_txt_length=24):
        assert os.path.exists(jsonl_filename), "The annotation datafile {} does not exist!".format(jsonl_filename)
logging.debug(f'Loading jsonl data from {jsonl_filename}.')
self.texts = []
with open(jsonl_filename, "r", encoding="utf-8") as fin:
for line in fin:
obj = json.loads(line.strip())
text_id = obj['text_id']
text = obj['text']
self.texts.append((text_id, text))
logging.debug(f'Finished loading jsonl data from {jsonl_filename}.')
self.max_txt_length = max_txt_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text_id, text = self.texts[idx]
text = tokenize([_preprocess_text(str(text))], context_length=self.max_txt_length)[0]
return text_id, text
class EvalImgDataset(Dataset):
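    """Evaluation image dataset: streams (img_id, image_tensor) pairs from an LMDB store."""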
def __init__(self, lmdb_imgs, resolution=224):
        assert os.path.isdir(lmdb_imgs), "The image LMDB directory {} does not exist!".format(lmdb_imgs)
logging.debug(f'Loading image LMDB from {lmdb_imgs}.')
self.env_imgs = lmdb.open(lmdb_imgs, readonly=True, create=False, lock=False, readahead=False, meminit=False)
self.txn_imgs = self.env_imgs.begin(buffers=True)
self.cursor_imgs = self.txn_imgs.cursor()
self.iter_imgs = iter(self.cursor_imgs)
self.number_images = int(self.txn_imgs.get(key=b'num_images').tobytes().decode('utf-8'))
logging.info("The specified LMDB directory contains {} images.".format(self.number_images))
self.transform = self._build_transform(resolution)
def _build_transform(self, resolution):
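        # mean/std below are the standard CLIP image normalization constants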
normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
return Compose([
Resize((resolution, resolution), interpolation=InterpolationMode.BICUBIC),
_convert_to_rgb,
ToTensor(),
normalize,
])
def __len__(self):
return self.number_images
def __getitem__(self, idx):
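        # NOTE: `idx` is ignored; records are read strictly in LMDB cursor order, so this
        # dataset assumes a SequentialSampler and a single-process loader (num_workers=0).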
img_id, image_b64 = next(self.iter_imgs)
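        # skip the bookkeeping record that stores the total image count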
if img_id == b"num_images":
img_id, image_b64 = next(self.iter_imgs)
img_id = img_id.tobytes()
image_b64 = image_b64.tobytes()
img_id = int(img_id.decode(encoding="utf8", errors="ignore"))
image_b64 = image_b64.decode(encoding="utf8", errors="ignore")
image = Image.open(BytesIO(base64.urlsafe_b64decode(image_b64))) # already resized
image = self.transform(image)
return img_id, image
@dataclass
class DataInfo:
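    """Bundles a DataLoader with the sampler that drives it (None when no sampler is used)."""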
dataloader: DataLoader
    sampler: torch.utils.data.Sampler  # SequentialSampler here; None for the zero-shot loader
def get_eval_txt_dataset(args, max_txt_length=24):
input_filename = args.text_data
dataset = EvalTxtDataset(
input_filename,
max_txt_length=max_txt_length)
num_samples = len(dataset)
sampler = SequentialSampler(dataset)
dataloader = DataLoader(
dataset,
batch_size=args.text_batch_size,
num_workers=0,
pin_memory=True,
sampler=sampler,
drop_last=False,
)
dataloader.num_samples = num_samples
dataloader.num_batches = len(dataloader)
return DataInfo(dataloader, sampler)
def fetch_resolution(vision_model):
# fetch the resolution from the vision model config
vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{vision_model.replace('/', '-')}.json"
with open(vision_model_config_file, 'r') as fv:
model_info = json.load(fv)
return model_info["image_resolution"]
def get_eval_img_dataset(args):
lmdb_imgs = args.image_data
dataset = EvalImgDataset(
lmdb_imgs, resolution=fetch_resolution(args.vision_model))
num_samples = len(dataset)
sampler = SequentialSampler(dataset)
dataloader = DataLoader(
dataset,
batch_size=args.img_batch_size,
        num_workers=0,  # EvalImgDataset reads from a shared LMDB cursor, so worker processes must be disabled
pin_memory=True,
sampler=sampler,
drop_last=False,
)
dataloader.num_samples = num_samples
dataloader.num_batches = len(dataloader)
return DataInfo(dataloader, sampler)
def get_zeroshot_dataset(args, preprocess_fn):
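    # ImageFolder expects the standard layout: one sub-directory per class label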
dataset = datasets.ImageFolder(args.datapath, transform=preprocess_fn)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.img_batch_size,
num_workers=args.num_workers,
sampler=None,
)
    return DataInfo(dataloader, None)
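
# --- Usage sketch (illustrative only) ---
# A minimal example of driving the two evaluation loaders. The `args` values
# below are hypothetical placeholders, not paths shipped with this module; the
# attribute names simply mirror what the builder functions above read.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        text_data="eval_texts.jsonl",   # hypothetical JSONL of {"text_id": ..., "text": ...} lines
        image_data="eval_imgs_lmdb",    # hypothetical LMDB directory produced by the preprocessing step
        vision_model="ViT-B-16",        # must match a config file under clip/model_configs/
        text_batch_size=64,
        img_batch_size=32,
    )
    txt_info = get_eval_txt_dataset(args, max_txt_length=24)
    img_info = get_eval_img_dataset(args)
    for text_ids, tokens in txt_info.dataloader:
        pass  # feed `tokens` to the text encoder
    for img_ids, images in img_info.dataloader:
        pass  # feed `images` to the vision encoder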