Spaces:
Runtime error
Runtime error
import functools | |
import logging | |
import math | |
import random | |
import sys | |
from dataclasses import dataclass | |
from multiprocessing import Value | |
import time | |
import os | |
import numpy as np | |
import pickle as pkl | |
from open_flamingo.train.instruction_template import ( | |
VG_RELATION_TEMPLATES, | |
PISC_TEMPLATES, | |
) | |
import torch | |
import webdataset as wds | |
from PIL import Image | |
from torch.utils.data import DataLoader, IterableDataset, get_worker_info | |
from torch.utils.data.distributed import DistributedSampler | |
from webdataset.tariterators import ( | |
base_plus_ext, | |
tar_file_expander, | |
url_opener, | |
valid_sample, | |
) | |
from groundingdino.demo.caption_grounder import caption_grounder | |
from groundingdino.demo.inference_on_laion import add_loc_to_text | |
from groundingdino.demo.inference_on_laion import nms_without_score | |
from groundingdino.demo.inference_on_laion import calculate_iou | |
# Raise PIL's MAX_IMAGE_PIXELS limit so very large web images can still be decoded.
Image.MAX_IMAGE_PIXELS = 1_000_000_000
# Nominal dataset sizes. LAION2B_NUM_SAMPLE is used by the dataset factories in
# this file to derive ``dataloader.num_batches`` for their endless streams.
LAION2B_NUM_SAMPLE = 1_500_000_000
VQAV2_TRAIN_NUM_SAMPLE = 1_828_467
# Divisor applied to relation-prediction boxes in group_by_keys_nothrow
# (presumably the side length the predictions were made at -- TODO confirm).
VG_RELATION_BBOX_SIZE = 600
# Relation (predicate) label vocabulary; index 0 is "__background__".
REL_LABELS = [
    "__background__",
    "above", "across", "against", "along", "and", "at", "attached to",
    "behind", "belonging to", "between", "carrying", "covered in", "covering",
    "eating", "flying in", "for", "from", "growing on", "hanging from", "has",
    "holding", "in", "in front of", "laying on", "looking at", "lying on",
    "made of", "mounted on", "near", "of", "on", "on back of", "over",
    "painted on", "parked on", "part of", "playing", "riding", "says",
    "sitting on", "standing on", "to", "under", "using", "walking in",
    "walking on", "watching", "wearing", "wears", "with",
]
try: | |
import horovod.torch as hvd | |
except ImportError: | |
hvd = None | |
class ConcatDataset(IterableDataset):
    """Pack consecutive samples from an upstream iterator into training sequences.

    Samples are pulled from ``dataset`` (an iterator; ``next`` is called on it
    directly) until at least ``max_length`` tokens are collected, or after one
    sample when ``single`` is set. The token tensors are concatenated along the
    last dim and, unless ``single``, truncated to ``max_length``.

    Accepted sample layouts:

    * text-only: ``((input_ids, attention_mask),)``
    * image: ``(image, input_ids, attention_mask[, added_bbox[, relations]])``
      (the 3-element form is what ``preprocess_caption`` produces)

    Text-only sequences are yielded as ``(input_ids, attention_mask)``. Image
    sequences are yielded as ``(images, num_images, image_start_index_list,
    input_ids, attention_mask, added_bboxes, relations)``. Before that, the last
    image is padded out when its begin/end markers, or the delimiter count,
    show that it was cut by truncation.
    """

    def __init__(
        self, dataset, max_length,
        delimiter_id, pad_id=None, media_id=None, endofmedia_id=None,
        image_embedding_size=-2, single=False, box_id=None, visual_id=None,
    ):
        self.dataset = dataset
        self.max_length = max_length
        self.delimiter_id = torch.ones(1, 1).long() * delimiter_id
        # Optional ids always exist as attributes (None when not supplied), so
        # a missing one is reported explicitly instead of as an AttributeError.
        self.pad_id = int(pad_id) if pad_id is not None else None
        self.media_id = (
            torch.ones(1, 1).long() * int(media_id) if media_id is not None else None
        )
        self.endofmedia_id = (
            torch.ones(1, 1).long() * int(endofmedia_id) if endofmedia_id is not None else None
        )
        self.image_embedding_size = None
        if image_embedding_size > 0:
            logging.info(f"image_embedding_size: {image_embedding_size}")
            # +2 presumably covers the <|#image#|>/<|#endofimage#|> markers -- TODO confirm.
            self.image_embedding_size = image_embedding_size + 2
        self.single = single
        self.box_id = box_id
        self.visual_id = visual_id

    def __iter__(self):
        """Yield packed sequences forever (until the upstream iterator is exhausted)."""
        while True:
            input_ids_list = []
            attention_mask_list = []
            image_list = []
            added_bbox_list = []
            relations_list = []
            cnt = 0
            while cnt < self.max_length:
                sample = next(self.dataset)
                if len(sample) >= 3:
                    # Image sample; boxes and relations are optional trailing fields.
                    image_list.append(sample[0].unsqueeze(0))
                    input_ids = sample[1]
                    attention_mask = sample[2]
                    if len(sample) >= 4:
                        added_bbox_list.append(sample[3])
                    if len(sample) == 5:
                        relations_list.append(sample[4])
                else:
                    # Text-only sample wrapped in a 1-tuple.
                    text_pair = sample[0]
                    input_ids = text_pair[0]
                    attention_mask = text_pair[1]
                input_ids_list.append(input_ids)
                attention_mask_list.append(attention_mask)
                cnt += input_ids.shape[-1]
                if self.single:
                    break
            # Concatenate along the token axis and drop the leading dim
            # (presumably a batch dim of 1 from the tokenizer).
            input_ids = torch.cat(input_ids_list, dim=-1)[0]
            attention_mask = torch.cat(attention_mask_list, dim=-1)[0]
            if not self.single:
                input_ids = input_ids[:self.max_length]
                attention_mask = attention_mask[:self.max_length]

            if not image_list:
                yield input_ids, attention_mask
                continue

            if self.media_id is None or self.endofmedia_id is None or self.pad_id is None:
                raise ValueError(
                    "ConcatDataset needs pad_id, media_id and endofmedia_id to pack image samples"
                )
            # TODO: fix visual number not match
            images = torch.cat(image_list, dim=0)
            image_begin = (input_ids == self.media_id[0, 0]).nonzero().view(-1)
            image_end = (input_ids == self.endofmedia_id[0, 0]).nonzero().view(-1)
            if len(image_begin) != len(image_end):
                # Truncation cut inside the last image span: pad that image out.
                assert len(image_begin) == len(image_end) + 1
                input_ids[image_begin[-1]:] = self.pad_id
                attention_mask[image_begin[-1]:] = 0
                image_begin = image_begin[:-1]
            eos_token_num = len((input_ids == self.delimiter_id[0, 0]).nonzero().view(-1))
            if eos_token_num != len(image_begin) + 1:
                # NOTE(review): expects one more delimiter than surviving images;
                # otherwise the last image's sample is padded out. Confirm intent.
                input_ids[image_begin[-1]:] = self.pad_id
                attention_mask[image_begin[-1]:] = 0
                image_begin = image_begin[:-1]
                image_end = image_end[:-1]
            num_images = len(image_end)
            images = images[:num_images]
            added_bbox_list = added_bbox_list[:num_images]
            relations_list = relations_list[:num_images]
            # Offset one past each <|#image#|> token.
            image_start_index_list = (image_begin + 1).tolist()
            # Flatten into a fresh list: never mutates an upstream sample's list,
            # and yields [] when the samples carried no boxes.
            expand_list = [bbox for bboxes in added_bbox_list for bbox in bboxes]
            yield images, len(images), image_start_index_list, input_ids, attention_mask, expand_list, relations_list
class SharedEpoch:
    """Integer epoch counter kept in a ``multiprocessing.Value``.

    Lets the epoch set by the training loop be read from dataloader worker
    processes.
    """

    def __init__(self, epoch: int = 0):
        # Type code "i" stores the epoch as a C int in shared memory.
        self.shared_epoch = Value("i", epoch)

    def set_value(self, epoch):
        """Store ``epoch`` in the shared slot."""
        slot = self.shared_epoch
        slot.value = epoch

    def get_value(self):
        """Return the epoch currently stored in the shared slot."""
        current = self.shared_epoch.value
        return current
@dataclass
class DataInfo:
    """Bundle a dataloader with the objects needed to advance its epoch.

    The dataset factories construct this as ``DataInfo(dataloader=...,
    shared_epoch=...)``; the ``@dataclass`` decorator generates the keyword
    ``__init__`` that call relies on.
    """

    dataloader: DataLoader
    sampler: DistributedSampler = None
    # String annotation: SharedEpoch is only needed as a type hint here.
    shared_epoch: "SharedEpoch" = None

    def set_epoch(self, epoch):
        """Forward ``epoch`` to the shared epoch store and distributed sampler, when present."""
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)
def filter_no_caption_or_no_image(sample):
    """Return True only if ``sample`` has a ``txt`` entry and a png/jpg/jpeg image."""
    if "txt" not in sample:
        return False
    return any(ext in sample for ext in ("png", "jpg", "jpeg"))
def log_and_continue(exn):
    """Webdataset error handler that swallows ``exn`` and lets the pipeline continue.

    Exceptions whose repr mentions ValueError or KeyError are dropped silently,
    to avoid log spam; all others are logged as a warning. Always returns True.
    """
    description = repr(exn)
    is_noisy = "ValueError" in description or "KeyError" in description
    if not is_noisy:
        logging.warning(f"Handling webdataset error ({description}). Ignoring.")
    return True
# DEBUG | |
# log_and_continue = None | |
# DEBUG | |
def _load_relation_predictions(relation_data_dir):
    """Load the relation predictions stored for one tar shard.

    Reads ``custom_data_info.pkl`` and ``custom_prediction.pkl`` from
    ``relation_data_dir``. Returns ``(files_to_idx, prediction)``, where
    ``files_to_idx`` maps the text after the last ``#`` of each
    ``idx_to_files`` entry to that entry's index. Returns ``None`` when the
    files are missing or lack the expected keys.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted, locally produced files.
    try:
        with open(os.path.join(relation_data_dir, "custom_data_info.pkl"), "rb") as f:
            data_info = pkl.load(f)
        with open(os.path.join(relation_data_dir, "custom_prediction.pkl"), "rb") as f:
            prediction = pkl.load(f)
        idx_to_files = data_info["idx_to_files"]
        # Not used here, but a shard missing these keys is treated as having no relation data.
        data_info["ind_to_classes"]
        data_info["ind_to_predicates"]
        files_to_idx = {x.split("#")[-1]: i for i, x in enumerate(idx_to_files)}
    except Exception:
        return None
    return files_to_idx, prediction


def _lookup_relation_data(relation_info, prefix):
    """Return the prediction for sample ``prefix`` with bboxes divided by VG_RELATION_BBOX_SIZE, or {}.

    Works on a shallow copy, so the cached prediction is never rescaled twice.
    """
    if relation_info is None:
        return dict()
    files_to_idx, prediction = relation_info
    try:
        entry = dict(prediction[files_to_idx[prefix]])
        entry["bbox"] = [np.array(bbox) / VG_RELATION_BBOX_SIZE for bbox in entry["bbox"]]
    except Exception:
        return dict()
    return entry


def group_by_keys_nothrow(
    data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None,
    relation_data_root="/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_all_data_relation",
):
    """Return function over iterator that groups key, value pairs into samples.

    Consecutive file entries (dicts with ``fname``, ``data`` and ``__url__``)
    that share a key prefix are merged into one sample dict. Every sample also
    gets a ``relation_data`` entry. For shards whose URL contains
    ``blip2_all_data_ground`` it is that sample's relation prediction;
    otherwise, or when no prediction is found, it is ``{}``.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if given, only these suffixes are stored in the sample
    :param handler: unused; kept for signature compatibility
    :param relation_data_root: directory with one sub-directory per ground
        shard, named after the tar's basename up to its first "."
    """
    current_sample = None
    current_url = None
    # (files_to_idx, prediction) for the shard at current_url, or None.
    relation_info = None
    for filesample in data:
        assert isinstance(filesample, dict)
        url = filesample["__url__"]
        is_ground_shard = "blip2_all_data_ground" in url
        if url != current_url:
            # New shard: (re)load its relation predictions, if it has any.
            current_url = url
            relation_info = None
            if is_ground_shard:
                tar_idx = url.split("/")[-1].split(".")[0]
                relation_info = _load_relation_predictions(os.path.join(relation_data_root, tar_idx))
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
        if (
            current_sample is None
            or prefix != current_sample["__key__"]
            or suffix in current_sample
        ):
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=url)
            if is_ground_shard:
                current_sample["relation_data"] = _lookup_relation_data(relation_info, prefix)
            else:
                current_sample["relation_data"] = dict()
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    """Turn a stream of shard URLs into grouped samples.

    Re-implementation of webdataset's ``tarfile_to_samples`` that groups with
    ``group_by_keys_nothrow`` (which does not throw on a repeated suffix).
    ``handler`` is given to the URL opener and the tar expander.
    """
    opened = url_opener(src, handler=handler)
    return group_by_keys_nothrow(tar_file_expander(opened, handler=handler), handler=handler)
def pytorch_worker_seed(increment=0):
    """Return a seed for the current dataloader worker.

    Inside a PyTorch dataloader worker this is the worker's own seed, shifted
    by ``increment * max(1, num_workers)`` when ``increment`` is non-zero.
    Outside a worker it falls back to ``wds.utils.pytorch_worker_seed()``.
    """
    info = get_worker_info()
    if info is None:
        # Not in a dataloader worker: use webdataset's rank-based seed.
        return wds.utils.pytorch_worker_seed()
    seed = info.seed
    if increment:
        # Stride by the worker count so increments of different workers never collide.
        seed += increment * max(1, info.num_workers)
    return seed
# Shuffle buffer sizes. The sample-level pair is passed to wds.shuffle in the
# dataset factories below; the shard-level pair is not referenced in this file.
_SHARD_SHUFFLE_SIZE = 2_000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5_000
_SAMPLE_SHUFFLE_INITIAL = 1_000
class ResampledShards2(IterableDataset):
    """Iterable dataset that yields ``dict(url=...)`` for shard URLs in shuffled order.

    Every call to ``__iter__`` shuffles ``self.urls`` in place and yields each
    URL exactly once. Despite the name, shards are not drawn with replacement,
    and ``nshards`` is stored but never consulted.
    """

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Expand ``urls`` and store the shuffling configuration.

        :param urls: a list of URLs as a Python list or brace notation string
        :param nshards: kept for API compatibility; not used by ``__iter__``
        :param worker_seed: optional zero-argument callable returning a base seed
        :param deterministic: if true, reseed the RNG at the start of every pass
        :param epoch: a ``SharedEpoch`` to read the epoch from, or an int that
            is incremented locally on every pass
        """
        super().__init__()
        self.urls = wds.shardlists.expand_urls(urls)
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Shuffle the URL list and yield ``dict(url=url)`` for each entry."""
        if not isinstance(self.epoch, SharedEpoch):
            # NOTE: local epoch tracking is unreliable with several dataloader
            # workers or training processes; each copy may wrap at a different
            # time (or never).
            self.epoch += 1
            current_epoch = self.epoch
        else:
            current_epoch = self.epoch.get_value()

        if self.deterministic:
            if self.worker_seed is not None:
                base_seed = self.worker_seed() + current_epoch
            else:
                # PyTorch worker seed, spaced out by epoch.
                base_seed = pytorch_worker_seed(current_epoch)
            # Wall-clock time is mixed in, so the order still differs between
            # runs; ``deterministic`` only forces a reseed on every pass.
            self.rng.seed(base_seed + int(time.time()))

        self.rng.shuffle(self.urls)
        for shard_url in self.urls:
            yield {"url": shard_url}
def preprocess_image(sample, image_processor):
    """Apply ``image_processor`` to a decoded image and return its output as-is."""
    return image_processor(sample)
def preprocess_text(sample, tokenizer, max_length, single=False):
    """Tokenize ``sample`` (whitespace-stripped, prefixed with the BOS token).

    Output is truncated to ``max_length``; with ``single=True`` it is also
    padded to ``max_length``. Tensors are requested in PyTorch format.

    Returns:
        ``(input_ids, attention_mask)`` from the tokenizer output.
    """
    tokenizer_kwargs = dict(return_tensors="pt", max_length=max_length, truncation=True)
    if single:
        tokenizer_kwargs["padding"] = "max_length"
    encoded = tokenizer(tokenizer.bos_token + sample.strip(), **tokenizer_kwargs)
    return encoded["input_ids"], encoded["attention_mask"]
def preprocess_encoded_text(sample, tokenizer, max_length):
    """Decode UTF-8 ``sample`` bytes and tokenize the text with ``preprocess_text``."""
    text = sample.decode("utf-8")
    tokenized = preprocess_text(text, tokenizer, max_length=max_length)
    return tokenized
def _merge_bbox_previsual(added_bbox_list):
    """Collapse each group of boxes into the one box enclosing them all.

    Each element is indexed as a 2-D tensor whose columns 0-3 hold
    ``x1, y1, x2, y2``. Returns a list of ``(1, 4)`` tensors with the same
    dtype and device as their inputs.
    """
    def _enclosing_box(boxes):
        top_left = boxes[:, 0:2].min(dim=0).values.tolist()
        bottom_right = boxes[:, 2:4].max(dim=0).values.tolist()
        corners = top_left + bottom_right
        return torch.tensor(corners, device=boxes.device, dtype=boxes.dtype).unsqueeze(0)

    return [_enclosing_box(boxes) for boxes in added_bbox_list]
def _find_idx(text, subtext):
    """Return the start offsets of the non-overlapping occurrences of ``subtext`` in ``text``.

    The scan runs left to right and resumes just past each match, so
    ``_find_idx("aaaa", "aa") == [0, 2]``.

    Raises:
        ValueError: if ``subtext`` is empty (an empty pattern never advances the scan).
    """
    if not subtext:
        raise ValueError("subtext must be a non-empty string")
    locs = []
    loc = text.find(subtext)
    while loc != -1:
        locs.append(loc)
        loc = text.find(subtext, loc + len(subtext))
    return locs
def preprocess_ground_caption(sample, image_processor, tokenizer, image_embedding_size, generator, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None, args=None):
    """Turn a grounded image-caption sample into model inputs with box tokens.

    ``sample`` is unpacked as ``(image, caption, logits_filt, boxes_filt,
    relation_data)``. With probability ``prob_ground`` (always when it is 1.0),
    ``generator.postprocess`` turns the stored logits/boxes into boxes and
    phrases, and ``add_loc_to_text`` inserts ``<|#visual#|>`` markers into the
    caption. Each marker is then expanded to ``<|#endofobject#|><|#visual#|>``
    followed by one ``<|#box#|>`` per box kept by ``nms_without_score`` and
    ``<|#endofattr#|>``.

    With ``use_format_v2``, some objects get a ``<|#previsual#|><|#prebox#|>``
    group whose box is the merged box of the object (see the loop below for
    the exact rule). With ``args.roi_align``, pad tokens are inserted before
    every ``<|#prebox#|>``/``<|#box#|>`` token.

    ``args`` must not be None. Depending on the branch taken it reads
    ``expand``, ``longer_previsual``, ``no_visual``, ``no_previsual``,
    ``roi_align``, ``roi_output_size`` and ``only_grounded_sample``.
    ``add_visual_token`` is unused; ``single`` must be False.

    Returns:
        ``(image, input_ids, attention_mask, added_bbox, relations)``;
        ``relations`` is always an empty list here.

    Raises:
        NotImplementedError: when ``logits_filt`` and ``boxes_filt`` are both
            1-D of length 4 (per the inline note, samples lacking relation data).
        ValueError: when ``args.only_grounded_sample`` is set and the final
            caption contains no ``<|#visual#|>`` token.
    """
    # NOTE(review): this block's nesting was reconstructed from a whitespace-
    # stripped copy; in particular confirm where ``flag = False`` sits in the
    # format-v2 loop and that ``if args.roi_align`` is outside ``if use_format_v2``.
    assert max_length is not None
    assert not single, "single is not supported for preprocess_ground_caption"
    image, caption, logits_filt, boxes_filt, relation_data = sample
    if len(logits_filt.shape) == 1 and logits_filt.shape[0] == 4 and len(boxes_filt.shape) == 1 and boxes_filt.shape[0] == 4:
        raise NotImplementedError # lack relation data
        # NOTE(review): unreachable -- the raise above always fires first.
        return preprocess_visual_genome(sample=sample, image_processor=image_processor, tokenizer=tokenizer, image_embedding_size=image_embedding_size, prob_ground=prob_ground, single=single, use_format_v2=use_format_v2, add_visual_token=add_visual_token, max_length=max_length)
    image = preprocess_image(image, image_processor=image_processor)
    added_bbox = []
    # Ground with probability prob_ground; random.random() is only drawn when prob_ground != 0.
    if (prob_ground != 0 and random.random() <= prob_ground) or prob_ground == 1.0:
        boxes_filt, pred_phrases = generator.postprocess(logits_filt, boxes_filt, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
        caption, added_bbox = add_loc_to_text(
            boxes_filt, pred_phrases, caption,
            expand=args.expand, always_expand=args.longer_previsual,
        )
    visual_loc = []
    obj_loc = []  # unused
    endofobj_loc = []  # unused
    visual_token = "<|#visual#|>"
    previsual_token = "<|#previsual#|>"
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    end_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"
    end_of_attr_token = "<|#endofattr#|>"
    preend_of_attr_token = "<|#preendofattr#|>"
    visual_loc = _find_idx(caption, visual_token)
    # Diagnostic logging only; the assert below still enforces the invariant.
    try:
        if len(visual_loc) != len(added_bbox):
            logging.warning(f"visual_loc: {visual_loc}")
            logging.warning(f"added_bbox: {added_bbox}")
    except:
        pass
    assert len(visual_loc) == len(added_bbox)
    # Expand each <|#visual#|> marker in place; delta tracks how far later
    # offsets have shifted because of earlier insertions.
    delta = 0
    for i, (loc, boxes) in enumerate(zip(visual_loc, added_bbox)):
        loc += delta
        boxes = nms_without_score(boxes)
        added_bbox[i] = boxes
        added_tokens = end_token + visual_token + box_token * len(boxes) + end_of_attr_token
        caption = caption[:loc] + added_tokens + caption[len(visual_token) + loc:]
        delta += len(added_tokens) - len(visual_token)
    if use_format_v2:
        merge_added_bbox = _merge_bbox_previsual(added_bbox)
        # step 1: move <|#object#|> before the space char
        while caption.find(f" {object_token}") != -1:
            caption = caption.replace(f" {object_token}", f"{object_token} ")
        # step 2: add <|#previsual#|> after <|#object#|> for 75% except the first object
        i = 0
        II = -1  # index of the current object into added_bbox / merge_added_bbox
        if args.no_visual:
            flag = False
            # > 1, so ``random.random() < delete_visual_prob`` always holds.
            delete_visual_prob = 10.0
        else:
            flag = True
            delete_visual_prob = 0.75
        while i < len(caption):
            if caption[i: i + len(object_token)] == object_token:
                II += 1
                if (not args.longer_previsual and not flag and random.random() < delete_visual_prob) or (args.longer_previsual and (flag or random.random() < delete_visual_prob)):
                    # delete visual and add previsual
                    visual_start_idx = caption.find(end_token, i+1) + len(end_token)
                    visual_end_idx = caption.find(end_of_attr_token, visual_start_idx+1) + len(end_of_attr_token)
                    caption = caption[:visual_start_idx] + caption[visual_end_idx:]
                    caption = caption[:i + len(object_token)] + previsual_token + prebox_token + preend_of_attr_token + caption[i + len(object_token):]
                    # The previsual group uses the merged box of this object's boxes.
                    added_bbox[II] = merge_added_bbox[II]
            i += 1
            # NOTE(review): as reconstructed, flag is cleared after the first
            # character rather than after the first object -- confirm intent.
            flag = False
        if args.no_previsual and args.no_visual:
            caption = caption.replace(previsual_token, "").replace(prebox_token, "").replace(preend_of_attr_token, "")
            added_bbox = []
        # Map the pre-/attr end markers onto <|#object#|> / <|#endofobject#|>.
        caption = caption.replace(preend_of_attr_token, object_token).replace(end_of_attr_token, end_token)
    if args.roi_align:
        # Insert roi_output_size**2 - 1 pad tokens before every prebox/box token,
        # presumably so each box spans one position per RoI-aligned feature.
        i = 0
        pad_num = args.roi_output_size ** 2 - 1
        while i < len(caption):
            if caption[i: i + len(prebox_token)] == prebox_token:
                caption = caption[:i] + tokenizer.pad_token * pad_num + caption[i:]
                i += len(tokenizer.pad_token) * pad_num + len(prebox_token)
            elif caption[i: i + len(box_token)] == box_token:
                caption = caption[:i] + tokenizer.pad_token * pad_num + caption[i:]
                i += len(tokenizer.pad_token) * pad_num + len(box_token)
            # NOTE(review): after a match i already points past the token, so this
            # extra step skips the first char of an immediately following token;
            # adjacent <|#box#|> tokens therefore only get the first one padded.
            i += 1
    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
    relations = []
    if args.only_grounded_sample and "<|#visual#|>" not in caption:
        raise ValueError
    return image, input_ids, attention_mask, added_bbox, relations
def preprocess_visual_genome(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
    """Build a single-object grounding example from ``(image, caption, xyxy, _)``.

    The caption is wrapped as ``<|#image#|>`` + ``image_embedding_size`` pad
    tokens + ``<|#endofimage#|><|#object#|>caption<|#endofobject#|>``
    + ``<|#visual#|><|#box#|><|#endofattr#|>``. ``xyxy`` becomes the single
    box, cast to float32 and divided by 224 (presumably the pixel size the
    coordinates are expressed in -- TODO confirm).

    ``prob_ground``, ``use_format_v2`` and ``add_visual_token`` are accepted
    for signature compatibility but unused; ``single`` must be False.

    Returns:
        ``(image, input_ids, attention_mask, added_bbox)`` where ``added_bbox``
        is a one-element list holding the box tensor with a leading dim of 1.
    """
    assert max_length is not None
    assert not single, "single is not supported for preprocess_visual_genome"
    image, caption, xyxy, _ = sample
    image = preprocess_image(image, image_processor=image_processor)
    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|><|#object#|>" + caption.strip() + "<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
    added_bbox = [torch.tensor(np.expand_dims(xyxy, 0).astype(np.float32) / 224)]
    return image, input_ids, attention_mask, added_bbox
# Predicates that relation templates render without "is"; each maps to the
# base verb form used for "does <verb>" phrasing.
original_predicate = {
    "and": "and",
    "has": "have",
    "says": "say",
    "wears": "wear",
}
special_predicate = list(original_predicate)
def generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation):
    """Render a Visual Genome relation with a text template plus its box list.

    The relations "and" and "of" always use template 0; otherwise a template
    index is drawn uniformly from ``VG_RELATION_TEMPLATES``. Predicates in
    ``special_predicate`` are rendered without "is" and with their base verb.

    Returns:
        ``(text, added_bbox)``, where ``added_bbox`` is a list of tensors, each
        holding ``boxA`` or ``boxB`` with a leading dim of 1.

    Raises:
        NotImplementedError: if the chosen template index has no box layout (> 5).
    """
    if relation in ["and", "of"]:
        template_id = 0
    else:
        template_id = random.choice(range(len(VG_RELATION_TEMPLATES)))
    is_special = relation in special_predicate
    text = VG_RELATION_TEMPLATES[template_id].format(
        nameA=nameA,
        nameB=nameB,
        relation=relation,
        use_is="" if is_special else "is",
        is_or_does="does" if is_special else "is",
        relation_do=original_predicate[relation] if is_special else relation,
    )
    # Box sequence per template index (presumably the order in which the
    # template mentions the two objects).
    box_orders = {0: "AB", 1: "ABAB", 2: "AAB", 3: "BAB", 4: "AB", 5: "BA"}
    if template_id not in box_orders:
        raise NotImplementedError
    boxes_by_name = {"A": boxA, "B": boxB}
    added_bbox = [torch.tensor([boxes_by_name[which]]) for which in box_orders[template_id]]
    return text, added_bbox
def generate_pisc_sample(boxA, boxB, relation):
    """Render a PISC social relation with a random template plus its box list.

    Template 0 gets two separate single-box tensors; template 1 gets one tensor
    holding both boxes. In both cases the A/B order is chosen with probability 0.5.

    Returns:
        ``(text, added_bbox)``.

    Raises:
        NotImplementedError: if ``PISC_TEMPLATES`` yields a template index
            without a box layout (consistent with ``generate_vg_relation_sample``).
    """
    template_id = random.choice(range(len(PISC_TEMPLATES)))
    text = PISC_TEMPLATES[template_id].format(relation=relation)
    if template_id == 0:
        if random.random() < 0.5:
            added_bbox = [torch.tensor([boxA]), torch.tensor([boxB])]
        else:
            added_bbox = [torch.tensor([boxB]), torch.tensor([boxA])]
    elif template_id == 1:
        if random.random() < 0.5:
            added_bbox = [torch.tensor([boxA, boxB])]
        else:
            added_bbox = [torch.tensor([boxB, boxA])]
    else:
        raise NotImplementedError(f"no box layout for PISC template {template_id}")
    return text, added_bbox
def preprocess_instruct(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
    """Build one relation-instruction example from ``(image_path, dataset, data)``.

    ``dataset`` selects the text generator:

    * ``"pisc_relation_split"``: ``data`` is indexed as ``(boxA, boxB, relation)``.
    * ``"vg_relation"``: ``data`` is indexed as
      ``((boxA, nameA), (boxB, nameB), relation)``.

    The image is first resized to ``image_processor.transforms[0].size``
    (assumed to be an int -- TODO confirm) and then passed through
    ``image_processor``. The text is padded to ``max_length``. ``prob_ground``,
    ``single``, ``use_format_v2`` and ``add_visual_token`` are unused.

    Returns:
        ``(images, num_images, image_start_index_list, input_ids,
        attention_mask, added_bbox)``; ``images`` has a leading dim of 1.

    Raises:
        NotImplementedError: for any other ``dataset`` value.
    """
    image_path, dataset, data = sample
    size = image_processor.transforms[0].size
    # Resize inside the context manager so the source file is closed as soon as
    # the independent resized copy exists.
    with Image.open(image_path) as source_image:
        image = source_image.resize((size, size))
    if dataset == "pisc_relation_split":
        boxA = data[0]
        boxB = data[1]
        relation = data[2]
        text, added_bbox = generate_pisc_sample(boxA, boxB, relation)
    elif dataset == "vg_relation":
        boxA = data[0][0]
        nameA = data[0][1]
        boxB = data[1][0]
        nameB = data[1][1]
        relation = data[2]
        text, added_bbox = generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation)
    else:
        # Fail explicitly; log_and_continue logs NotImplementedError and skips the sample.
        raise NotImplementedError(f"unsupported instruct dataset: {dataset!r}")
    image = preprocess_image(image, image_processor=image_processor)
    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + text + tokenizer.eos_token
    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=True)
    images = image.unsqueeze(0)
    # preprocess_text prepends BOS, then <|#image#|>, so image pads start at 2
    # (assuming the tokenizer adds no further special tokens).
    image_start_index_list = [2]
    return images, len(images), image_start_index_list, input_ids, attention_mask, added_bbox
def preprocess_caption(sample, image_processor, tokenizer, image_embedding_size, max_length, single=False):
    """Turn an ``(image, caption)`` pair into ``(image, input_ids, attention_mask)``.

    The caption is prefixed with ``<|#image#|>``, ``image_embedding_size``
    pad tokens and ``<|#endofimage#|>`` before tokenization.
    """
    raw_image, raw_caption = sample
    image_prefix = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>"
    full_caption = image_prefix + raw_caption
    processed_image = preprocess_image(raw_image, image_processor=image_processor)
    input_ids, attention_mask = preprocess_text(full_caption, tokenizer, max_length=max_length, single=single)
    return processed_image, input_ids, attention_mask
def get_pile_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """Build an endless, packed text-only dataloader over ``args.pile_shards``.

    Each shard's ``txt`` entries are tokenized by ``preprocess_encoded_text``
    and packed to ``args.max_length`` tokens by ``ConcatDataset``.
    ``image_processor`` and ``floor`` are unused (uniform factory signature).

    Returns:
        ``DataInfo`` whose dataloader yields ``(input_ids, attention_mask)``
        batches, or ``(None, None)`` for a batch that cannot be stacked.
    """
    input_shards = args.pile_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)
    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    preprocess_text_fn = functools.partial(preprocess_encoded_text, tokenizer=tokenizer, max_length=args.max_length)
    pipeline = [
        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
        tarfile_to_samples_nothrow,
        wds.shuffle(
            bufsize=_SAMPLE_SHUFFLE_SIZE,
            initial=_SAMPLE_SHUFFLE_INITIAL,
        ),
        wds.to_tuple("txt", handler=log_and_continue),
        wds.map_tuple(
            preprocess_text_fn, handler=log_and_continue
        ),
    ]
    # with_epoch(sys.maxsize) will give us an infinite sample stream
    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
    delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
    dataset = ConcatDataset(iter(dataset), max_length=args.max_length, delimiter_id=delimiter_id)

    def text_collate_fn(items):
        """Stack packed sequences; returns (None, None) when the batch cannot be stacked."""
        try:
            input_ids = torch.cat([x[0].unsqueeze(0) for x in items], dim=0)
            attention_mask = torch.cat([x[1].unsqueeze(0) for x in items], dim=0)
        except Exception as exn:
            # Best-effort as before, but the dropped batch is now visible in the logs.
            logging.warning(f"text_collate_fn: dropping batch that could not be collated ({exn!r})")
            return None, None
        return input_ids, attention_mask

    dataloader = wds.WebLoader(
        dataset,
        batch_size=args.batch_size_pile,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=False,
        collate_fn=text_collate_fn,
    )
    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
# FIXME: this needs a local patch to the installed webdataset. In
# /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/webdataset/filters.py,
# line 433, change `combine_tensors=True` to `combine_tensors=False`.
def get_ground_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """Build the endless grounded image-caption dataloader over ``args.laion_shards``.

    Samples provide an image, caption, stored grounding outputs
    (``logits.pyd``/``boxes.pyd``) and ``relation_data``. They are converted by
    ``preprocess_ground_caption`` and packed to ``args.max_length`` tokens by
    ``ConcatDataset``.

    Returns:
        ``DataInfo`` whose dataloader yields ``(images, image_nums,
        image_start_index_list, input_ids, attention_mask, added_bboxes,
        relations)`` batches. ``dataloader.num_batches`` is derived from
        ``LAION2B_NUM_SAMPLE``, rounded down when ``floor`` is set, else up.
    """
    input_shards = args.laion_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)
    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    # NOTE(review): GroundingDINO config/checkpoint paths are hard-coded to one cluster.
    generator = caption_grounder(
        config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
        checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
        cpu_only=True,
        # box_threshold=0.5, text_threshold=0.3,
    )
    preprocess_ground_caption_fn = functools.partial(
        preprocess_ground_caption, image_processor=image_processor, tokenizer=tokenizer,
        image_embedding_size=args.vis_embed_size, single=args.single, generator=generator,
        prob_ground=args.prob_ground, use_format_v2=args.use_format_v2,
        add_visual_token=args.add_visual_token, max_length=args.max_length,
        args=args,
    )
    pipeline = [
        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
        tarfile_to_samples_nothrow,
        wds.shuffle(
            bufsize=_SAMPLE_SHUFFLE_SIZE,
            initial=_SAMPLE_SHUFFLE_INITIAL,
        ),
        wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", partial=True, handler=log_and_continue),
        wds.to_tuple("jpg;png;jpeg", "txt", "logits.pyd", "boxes.pyd", "relation_data", handler=log_and_continue),
        wds.map(
            preprocess_ground_caption_fn, handler=log_and_continue
        ),
    ]
    # with_epoch(sys.maxsize) will give us an infinite sample stream
    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    box_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
    visual_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    dataset = ConcatDataset(
        iter(dataset), max_length=args.max_length,
        delimiter_id=delimiter_id,
        pad_id=tokenizer.pad_token_id,
        media_id=media_token_id,
        endofmedia_id=endofmedia_token_id,
        box_id=box_id,
        visual_id=visual_id,
        image_embedding_size=args.vis_embed_size,
        single=args.single,
    )

    def image_collate_fn(items):
        """Batch packed sequences; per-image boxes are flattened across the batch."""
        images = torch.cat([x[0] for x in items], dim=0)
        image_nums = [x[1] for x in items]
        image_start_index_list = [x[2] for x in items]
        input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
        attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
        # Build a new list rather than extending the first item's list in place.
        expand_list = [bbox for x in items for bbox in x[5]]
        relations_list = [x[6] for x in items]
        return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list, relations_list

    dataloader = wds.WebLoader(
        dataset,
        batch_size=args.batch_size_laion,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=False,
        collate_fn=image_collate_fn,
    )
    round_fn = math.floor if floor else math.ceil
    global_batch_size = args.batch_size_laion * args.world_size
    num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
    dataloader.num_batches = num_batches
    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
def get_image_text_pair_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """Build the endless plain image-caption dataloader over ``args.laion_shards``.

    Captions are tokenized by ``preprocess_caption`` and packed by
    ``ConcatDataset``. The dataloader yields ``(images, image_nums,
    image_start_index_list, input_ids, attention_mask)`` batches.
    ``num_batches`` is derived from ``LAION2B_NUM_SAMPLE``, rounded down when
    ``floor`` is set, else up.
    """
    input_shards = args.laion_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)
    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
    # Shared epoch store so dataloader workers see the epoch set by the trainer.
    shared_epoch = SharedEpoch(epoch=epoch)

    caption_fn = functools.partial(
        preprocess_caption,
        image_processor=image_processor,
        tokenizer=tokenizer,
        image_embedding_size=args.vis_embed_size,
        single=args.single,
        max_length=args.max_length,
    )
    stages = [
        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
        tarfile_to_samples_nothrow,
        wds.shuffle(bufsize=_SAMPLE_SHUFFLE_SIZE, initial=_SAMPLE_SHUFFLE_INITIAL),
        wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
        wds.map(caption_fn, handler=log_and_continue),
    ]
    # with_epoch(sys.maxsize) makes the pipeline an effectively endless stream.
    stream = wds.DataPipeline(*stages).with_epoch(sys.maxsize)

    def last_token_id(text):
        return tokenizer(text, add_special_tokens=False)["input_ids"][-1]

    media_token_id = last_token_id("<|#image#|>")
    delimiter_id = last_token_id(tokenizer.eos_token)
    endofmedia_token_id = last_token_id("<|#endofimage#|>")
    packed = ConcatDataset(
        iter(stream),
        max_length=args.max_length,
        delimiter_id=delimiter_id,
        pad_id=tokenizer.pad_token_id,
        media_id=media_token_id,
        endofmedia_id=endofmedia_token_id,
        image_embedding_size=args.vis_embed_size,
        single=args.single,
    )

    def image_collate_fn(items):
        batch_images = torch.cat([item[0] for item in items], dim=0)
        counts = [item[1] for item in items]
        start_indices = [item[2] for item in items]
        batch_ids = torch.stack([item[3] for item in items])
        batch_mask = torch.stack([item[4] for item in items])
        return batch_images, counts, start_indices, batch_ids, batch_mask

    loader = wds.WebLoader(
        packed,
        batch_size=args.batch_size_laion,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=False,
        collate_fn=image_collate_fn,
    )
    rounding = math.floor if floor else math.ceil
    loader.num_batches = rounding(LAION2B_NUM_SAMPLE / (args.batch_size_laion * args.world_size))
    return DataInfo(dataloader=loader, shared_epoch=shared_epoch)
def get_instruct_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """Build the endless relation-instruction dataloader over ``args.laion_shards``.

    Shard entries provide ``image_path.txt``, ``dataset.txt`` and
    ``data.pyd``. ``preprocess_instruct`` turns them into padded examples, so
    no sequence packing is applied. The dataloader yields ``(images,
    image_nums, image_start_index_list, input_ids, attention_mask,
    added_bboxes)`` batches. ``num_batches`` is derived from
    ``LAION2B_NUM_SAMPLE``, rounded down when ``floor`` is set, else up.
    """
    input_shards = args.laion_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)
    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    preprocess_instruct_fn = functools.partial(
        preprocess_instruct, image_processor=image_processor, tokenizer=tokenizer,
        image_embedding_size=args.vis_embed_size,
        max_length=args.max_length,
    )
    pipeline = [
        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
        tarfile_to_samples_nothrow,
        wds.shuffle(
            bufsize=_SAMPLE_SHUFFLE_SIZE,
            initial=_SAMPLE_SHUFFLE_INITIAL,
        ),
        # Like the other pipelines, route decode failures to log_and_continue
        # instead of letting one corrupt sample abort the stream.
        # NOTE(review): data.pyd is unpickled here; shards must be trusted.
        wds.decode(partial=True, handler=log_and_continue),
        wds.to_tuple("image_path.txt", "dataset.txt", "data.pyd", handler=log_and_continue),
        wds.map(
            preprocess_instruct_fn, handler=log_and_continue
        ),
    ]
    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)

    def image_collate_fn(items):
        """Batch single-image examples; boxes are flattened across the batch."""
        images = torch.cat([x[0] for x in items], dim=0)
        image_nums = [x[1] for x in items]
        image_start_index_list = [x[2] for x in items]
        # Token tensors are concatenated on dim 0 without unsqueeze, presumably
        # because preprocess_instruct already returns them with a batch dim of 1.
        input_ids = torch.cat([x[3] for x in items], dim=0)
        attention_mask = torch.cat([x[4] for x in items], dim=0)
        # Build a new list rather than extending the first item's list in place.
        expand_list = [bbox for x in items for bbox in x[5]]
        return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list

    dataloader = wds.WebLoader(
        dataset,
        batch_size=args.batch_size_laion,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=False,
        collate_fn=image_collate_fn,
    )
    round_fn = math.floor if floor else math.ceil
    global_batch_size = args.batch_size_laion * args.world_size
    num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
    dataloader.num_batches = num_batches
    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
def get_dataset_fn(dataset_type):
    """Return the dataset factory for ``dataset_type``.

    Raises:
        NotImplementedError: for the recognised but unimplemented types "mmc4" and "vqav2".
        ValueError: for any other unknown type.
    """
    if dataset_type in ("mmc4", "vqav2"):
        raise NotImplementedError
    if dataset_type not in ("pile", "ground_image_text", "image_text", "instruct"):
        raise ValueError(f"Unsupported dataset type: {dataset_type}")
    factories = {
        "pile": get_pile_dataset,
        "ground_image_text": get_ground_laion_dataset,
        "image_text": get_image_text_pair_dataset,
        "instruct": get_instruct_dataset,
    }
    return factories[dataset_type]
def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
    """Build the ``DataInfo`` for ``dataset_type`` using the factory from ``get_dataset_fn``."""
    build = get_dataset_fn(dataset_type)
    return build(args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer)