import functools
import logging
import math
import random
import sys
from dataclasses import dataclass
from multiprocessing import Value
import time
import os
import numpy as np
import pickle as pkl
from open_flamingo.train.instruction_template import (
    VG_RELATION_TEMPLATES,
    PISC_TEMPLATES,
)

import torch
import webdataset as wds
from PIL import Image
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from webdataset.tariterators import (
    base_plus_ext,
    tar_file_expander,
    url_opener,
    valid_sample,
)
from groundingdino.demo.caption_grounder import caption_grounder
from groundingdino.demo.inference_on_laion import add_loc_to_text
from groundingdino.demo.inference_on_laion import nms_without_score
from groundingdino.demo.inference_on_laion import calculate_iou

Image.MAX_IMAGE_PIXELS = 1000000000
LAION2B_NUM_SAMPLE = 1500000000
VQAV2_TRAIN_NUM_SAMPLE = 1828467
# Relation-prediction boxes are divided by this value to normalize them.
VG_RELATION_BBOX_SIZE = 600
REL_LABELS = [
    '__background__', 'above', 'across', 'against', 'along', 'and', 'at',
    'attached to', 'behind', 'belonging to', 'between', 'carrying',
    'covered in', 'covering', 'eating', 'flying in', 'for', 'from',
    'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of',
    'laying on', 'looking at', 'lying on', 'made of', 'mounted on', 'near',
    'of', 'on', 'on back of', 'over', 'painted on', 'parked on', 'part of',
    'playing', 'riding', 'says', 'sitting on', 'standing on', 'to', 'under',
    'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with',
]

# Cluster-specific locations of auxiliary data, hoisted here so they are easy
# to find and override instead of being buried inside function bodies.
RELATION_DATA_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_all_data_relation"
GROUNDING_DINO_CONFIG = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth"

# Special tokens used when rewriting grounded captions.
_VISUAL_TOKEN = "<|#visual#|>"
_PREVISUAL_TOKEN = "<|#previsual#|>"
_BOX_TOKEN = "<|#box#|>"
_PREBOX_TOKEN = "<|#prebox#|>"
_END_TOKEN = "<|#endofobject#|>"
_OBJECT_TOKEN = "<|#object#|>"
_END_OF_ATTR_TOKEN = "<|#endofattr#|>"
_PREEND_OF_ATTR_TOKEN = "<|#preendofattr#|>"

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def _token_id(tokenizer, text):
    """Return the id of the last token ``tokenizer`` produces for ``text`` (no special tokens added)."""
    return tokenizer(text, add_special_tokens=False)["input_ids"][-1]


def _laion_num_batches(args, floor):
    """Number of global batches in one pass over ``LAION2B_NUM_SAMPLE`` samples.

    Rounds down when ``floor`` is true, up otherwise.
    """
    round_fn = math.floor if floor else math.ceil
    global_batch_size = args.batch_size_laion * args.world_size
    return round_fn(LAION2B_NUM_SAMPLE / global_batch_size)


class ConcatDataset(IterableDataset):
    """Pack consecutive samples from an iterator into sequences of ~``max_length`` tokens.

    ``dataset`` must be an iterator (``next()`` is called on it). Each item is
    either a 1-tuple wrapping ``(input_ids, attention_mask)`` (text only) or a
    tuple ``(image, input_ids, attention_mask, added_bbox[, relations])``.
    Token tensors are presumably shaped ``(1, seq_len)`` -- they are
    concatenated on the last dim and then indexed with ``[0]``.

    Text-only streams yield ``(input_ids, attention_mask)``. Image streams yield
    ``(images, num_images, image_start_index_list, input_ids, attention_mask,
    boxes, relations)``, where the last, truncated sample is removed (its tokens
    are replaced by ``pad_id`` and masked out).
    """

    def __init__(
        self,
        dataset,
        max_length,
        delimiter_id,
        pad_id=None,
        media_id=None,
        endofmedia_id=None,
        image_embedding_size=-2,
        single=False,
        box_id=None,
        visual_id=None,
    ):
        """
        :param dataset: iterator of preprocessed samples
        :param max_length: token budget of one packed sequence
        :param delimiter_id: token id separating samples (callers pass the EOS id)
        :param pad_id: id written over the tokens of dropped samples
        :param media_id: id of ``<|#image#|>``
        :param endofmedia_id: id of ``<|#endofimage#|>``
        :param single: yield exactly one sample per sequence, without truncation
        """
        self.dataset = dataset
        self.max_length = max_length
        self.delimiter_id = torch.ones(1, 1).long() * delimiter_id
        if pad_id is not None:
            self.pad_id = int(pad_id)
        if media_id is not None:
            self.media_id = torch.ones(1, 1).long() * int(media_id)
        if endofmedia_id is not None:
            self.endofmedia_id = torch.ones(1, 1).long() * int(endofmedia_id)
        if image_embedding_size > 0:
            logging.info(f"image_embedding_size: {image_embedding_size}")
        # The default of -2 makes this 0 when no image embedding size is given.
        self.image_embedding_size = image_embedding_size + 2
        self.single = single
        self.box_id = box_id
        self.visual_id = visual_id

    def __iter__(self):
        while True:
            input_ids_list = []
            attention_mask_list = []
            image_list = []
            added_bbox_list = []
            relations_list = []
            num_tokens = 0
            # Draw samples until the token budget is reached (or once if single).
            while num_tokens < self.max_length:
                sample = next(self.dataset)
                if len(sample) >= 4:
                    image_list.append(sample[0].unsqueeze(0))
                    input_ids = sample[1]
                    attention_mask = sample[2]
                    added_bbox_list.append(sample[3])
                    if len(sample) == 5:
                        relations_list.append(sample[4])
                else:
                    # Text-only samples arrive wrapped in a 1-tuple (wds.map_tuple).
                    text_pair = sample[0]
                    input_ids = text_pair[0]
                    attention_mask = text_pair[1]
                input_ids_list.append(input_ids)
                attention_mask_list.append(attention_mask)
                num_tokens += input_ids.shape[-1]
                if self.single:
                    break

            input_ids = torch.cat(input_ids_list, dim=-1)[0]
            attention_mask = torch.cat(attention_mask_list, dim=-1)[0]
            if not self.single:
                input_ids = input_ids[:self.max_length]
                attention_mask = attention_mask[:self.max_length]

            if not image_list:
                yield input_ids, attention_mask
                continue

            # TODO: fix visual number not match
            images = torch.cat(image_list, dim=0)
            image_begin = (input_ids == self.media_id[0, 0]).nonzero().view(-1)
            image_end = (input_ids == self.endofmedia_id[0, 0]).nonzero().view(-1)
            if len(image_begin) != len(image_end):
                # Truncation cut through the last image's tokens: blank out
                # everything from that image's start.
                assert len(image_begin) == len(image_end) + 1
                input_ids[image_begin[-1]:] = self.pad_id
                attention_mask[image_begin[-1]:] = 0
                image_begin = image_begin[:-1]
            eos_token_num = len((input_ids == self.delimiter_id[0, 0]).nonzero().view(-1))
            if eos_token_num != len(image_begin) + 1:
                # Delimiter count disagrees with the image count: drop the last
                # sample as well.
                input_ids[image_begin[-1]:] = self.pad_id
                attention_mask[image_begin[-1]:] = 0
                image_begin = image_begin[:-1]
                image_end = image_end[:-1]
            # Keep only images (and their boxes/relations) whose tokens survived.
            kept = len(image_end)
            images = images[:kept]
            added_bbox_list = added_bbox_list[:kept]
            relations_list = relations_list[:kept]
            image_start_index_list = (image_begin + 1).tolist()
            # Flatten into a new list rather than extending the first sample's list.
            expand_list = [box for boxes in added_bbox_list for box in boxes]
            yield images, len(images), image_start_index_list, input_ids, attention_mask, expand_list, relations_list


class SharedEpoch:
    """Epoch counter stored in a ``multiprocessing.Value`` so dataloader workers can read it."""

    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value("i", epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value


@dataclass
class DataInfo:
    """Bundle of a dataloader with the objects needed to advance its epoch."""

    dataloader: DataLoader
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        """Propagate ``epoch`` to the shared counter and to a DistributedSampler, if present."""
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)


def filter_no_caption_or_no_image(sample):
    """Keep samples that have a ``txt`` entry and a png/jpg/jpeg image."""
    return ("txt" in sample) and (
        "png" in sample or "jpg" in sample or "jpeg" in sample
    )


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue.

    Exceptions whose repr mentions ``ValueError`` or ``KeyError`` are skipped
    silently (no warning), to avoid spamming the logs.
    """
    if "ValueError" in repr(exn) or "KeyError" in repr(exn):
        return True
    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True


def _load_pickle(path):
    """Unpickle ``path`` and close the file afterwards.

    NOTE: pickle can execute arbitrary code; only use on trusted local files.
    """
    with open(path, "rb") as f:
        return pkl.load(f)


def _lookup_relation_data(prediction, files_to_idx, prefix):
    """Return the relation prediction for ``prefix`` with boxes divided by ``VG_RELATION_BBOX_SIZE``.

    Returns ``{}`` when the entry is missing or malformed (best effort). A copy
    is returned, so the cached ``prediction`` entry is never modified and cannot
    be normalized twice.
    """
    try:
        entry = prediction[files_to_idx[prefix]]
        normalized = dict(entry)
        normalized["bbox"] = [np.array(bbox) / VG_RELATION_BBOX_SIZE for bbox in entry["bbox"]]
        return normalized
    except Exception:
        return dict()


def group_by_keys_nothrow(
    data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
):
    """Return function over iterator that groups key, value pairs into samples.

    For tars whose URL contains ``blip2_all_data_ground``, each sample also gets
    a ``relation_data`` entry loaded from ``RELATION_DATA_ROOT/<tar name>``
    (``{}`` if unavailable). All other samples get ``relation_data = {}``.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if given, only these suffixes are stored in the sample
    :param handler: unused; kept for interface compatibility
    """
    current_sample = None
    tar_idx = None
    # Relation side data for the current tar.
    missing_file = True
    prediction = None
    files_to_idx = {}
    for filesample in data:
        assert isinstance(filesample, dict)
        url = filesample["__url__"]
        is_relation_tar = "blip2_all_data_ground" in url
        current_tar_idx = url.split("/")[-1].split(".")[0]
        if current_tar_idx != tar_idx:
            tar_idx = current_tar_idx
            if is_relation_tar:
                relation_data_dir = os.path.join(RELATION_DATA_ROOT, tar_idx)
                try:
                    data_info = _load_pickle(os.path.join(relation_data_dir, "custom_data_info.pkl"))
                    prediction = _load_pickle(os.path.join(relation_data_dir, "custom_prediction.pkl"))
                    idx_to_files = data_info["idx_to_files"]
                    # Not used further; read so a malformed info file counts as missing.
                    _ = (data_info["ind_to_classes"], data_info["ind_to_predicates"])
                    files_to_idx = {x.split("#")[-1]: i for i, x in enumerate(idx_to_files)}
                    missing_file = False
                except Exception:
                    missing_file = True
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset
        if (
            current_sample is None
            or prefix != current_sample["__key__"]
            or suffix in current_sample
        ):
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=url)
            if is_relation_tar and not missing_file:
                current_sample["relation_data"] = _lookup_relation_data(prediction, files_to_idx, prefix)
            else:
                current_sample["relation_data"] = dict()
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample


def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    """Expand shard URLs into grouped samples without raising on bad files."""
    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples


def pytorch_worker_seed(increment=0):
    """get dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour using the seed already created for pytorch dataloader workers if it exists
        seed = worker_info.seed
        if increment:
            # space out seed increments so they can't overlap across workers in different iterations
            seed += increment * max(1, worker_info.num_workers)
        return seed
    # fallback to wds rank based seed
    return wds.utils.pytorch_worker_seed()


_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000


class ResampledShards2(IterableDataset):
    """An iterable dataset yielding ``dict(url=...)`` for every shard in a freshly shuffled order."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Prepare to iterate over shards, reshuffled on every pass.

        :param urls: a list of URLs as a Python list or brace notation string
        :param nshards: stored but not used by ``__iter__``
        :param worker_seed: optional callable returning a base seed
        :param deterministic: reseed the RNG on every pass
        :param epoch: a ``SharedEpoch`` or an int counter incremented per pass
        """
        super().__init__()
        urls = wds.shardlists.expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards (each shard exactly once per pass)."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic
            if self.worker_seed is None:
                # pytorch worker seed is init by arg.seed + rank + worker id
                seed = pytorch_worker_seed(epoch)
            else:
                seed = self.worker_seed() + epoch
            # NOTE(review): adding wall-clock time makes the order differ between
            # runs despite ``deterministic``; kept as-is since it looks deliberate.
            seed = seed + int(time.time())
            self.rng.seed(seed)
        self.rng.shuffle(self.urls)
        for url in self.urls:
            yield dict(url=url)


def preprocess_image(sample, image_processor):
    """Apply ``image_processor`` to one image."""
    image = image_processor(sample)
    return image


def preprocess_text(sample, tokenizer, max_length, single=False):
    """Tokenize ``bos_token + sample.strip()`` into PyTorch tensors, truncated to ``max_length``.

    With ``single`` the output is also padded to ``max_length``.
    Returns ``(input_ids, attention_mask)``.
    """
    kwargs = dict(return_tensors="pt", max_length=max_length, truncation=True)
    if single:
        kwargs["padding"] = 'max_length'
    text = tokenizer(tokenizer.bos_token + sample.strip(), **kwargs)
    return text["input_ids"], text["attention_mask"]


def preprocess_encoded_text(sample, tokenizer, max_length):
    """UTF-8 decode ``sample`` and tokenize it with :func:`preprocess_text`."""
    sample = sample.decode("utf-8")
    return preprocess_text(sample, tokenizer, max_length=max_length)


def _merge_bbox_previsual(added_bbox_list):
    """Return, for each ``(N, 4)`` xyxy box tensor, its enclosing box as a ``(1, 4)`` tensor."""
    bbox_list = []
    for bboxes in added_bbox_list:
        x1 = bboxes[:, 0].min()
        y1 = bboxes[:, 1].min()
        x2 = bboxes[:, 2].max()
        y2 = bboxes[:, 3].max()
        bbox_list.append(torch.tensor([x1, y1, x2, y2], device=bboxes.device, dtype=bboxes.dtype).unsqueeze(0))
    return bbox_list


def _find_idx(text, subtext):
    """Return the start offsets of all non-overlapping occurrences of ``subtext`` in ``text``.

    :raises ValueError: if ``subtext`` is empty (previously an infinite loop)
    """
    if not subtext:
        raise ValueError("subtext must be a non-empty string")
    locs = []
    loc = text.find(subtext)
    while loc != -1:
        locs.append(loc)
        loc = text.find(subtext, loc + len(subtext))
    return locs


def _expand_visual_tokens(caption, added_bbox):
    """Replace each ``<|#visual#|>`` with ``endofobject + visual + box*k + endofattr``.

    The i-th visual token is paired with ``added_bbox[i]``, which is first
    reduced by ``nms_without_score``; ``k`` is the number of remaining boxes.
    Returns the new caption and the list of reduced boxes.
    """
    visual_loc = _find_idx(caption, _VISUAL_TOKEN)
    if len(visual_loc) != len(added_bbox):
        logging.warning(f"visual_loc: {visual_loc}")
        logging.warning(f"added_bbox: {added_bbox}")
    assert len(visual_loc) == len(added_bbox)
    reduced_bbox = []
    delta = 0  # shift of later offsets caused by earlier replacements
    for loc, boxes in zip(visual_loc, added_bbox):
        loc += delta
        boxes = nms_without_score(boxes)
        reduced_bbox.append(boxes)
        added_tokens = _END_TOKEN + _VISUAL_TOKEN + _BOX_TOKEN * len(boxes) + _END_OF_ATTR_TOKEN
        caption = caption[:loc] + added_tokens + caption[loc + len(_VISUAL_TOKEN):]
        delta += len(added_tokens) - len(_VISUAL_TOKEN)
    return caption, reduced_bbox


def _apply_format_v2(caption, added_bbox, args):
    """Rewrite an expanded caption into the "format v2" layout.

    Object tokens are moved before the preceding space. Then, for a random
    subset of objects (all objects when ``args.no_visual``), the visual segment
    is removed, a previsual segment is inserted right after the object token,
    and that object's boxes are merged into one enclosing box. Returns the
    new caption and box list.
    """
    merge_added_bbox = _merge_bbox_previsual(added_bbox)
    added_bbox = list(added_bbox)
    # step 1: move <|#object#|> before the space char
    while f" {_OBJECT_TOKEN}" in caption:
        caption = caption.replace(f" {_OBJECT_TOKEN}", f"{_OBJECT_TOKEN} ")
    # step 2: add <|#previsual#|> after <|#object#|> for 75% except the first object
    if args.no_visual:
        flag = False
        delete_visual_prob = 10.0  # > 1, so every object is converted
    else:
        flag = True
        delete_visual_prob = 0.75
    i = 0
    obj_idx = -1
    while i < len(caption):
        if caption[i: i + len(_OBJECT_TOKEN)] == _OBJECT_TOKEN:
            obj_idx += 1
            if (
                (not args.longer_previsual and not flag and random.random() < delete_visual_prob)
                or (args.longer_previsual and (flag or random.random() < delete_visual_prob))
            ):
                # delete visual and add previsual
                visual_start_idx = caption.find(_END_TOKEN, i + 1) + len(_END_TOKEN)
                visual_end_idx = caption.find(_END_OF_ATTR_TOKEN, visual_start_idx + 1) + len(_END_OF_ATTR_TOKEN)
                caption = caption[:visual_start_idx] + caption[visual_end_idx:]
                insert_at = i + len(_OBJECT_TOKEN)
                caption = caption[:insert_at] + _PREVISUAL_TOKEN + _PREBOX_TOKEN + _PREEND_OF_ATTR_TOKEN + caption[insert_at:]
                added_bbox[obj_idx] = merge_added_bbox[obj_idx]
        i += 1
        # NOTE(review): ``flag`` is cleared after the first scanned character, so
        # only an object token at position 0 counts as "the first object".
        flag = False
    if args.no_previsual and args.no_visual:
        caption = caption.replace(_PREVISUAL_TOKEN, "").replace(_PREBOX_TOKEN, "").replace(_PREEND_OF_ATTR_TOKEN, "")
        added_bbox = []
    return caption, added_bbox


def _insert_roi_padding(caption, pad_token, pad_num):
    """Insert ``pad_num`` copies of ``pad_token`` before every prebox and box token.

    Scanning resumes right after each processed token, so consecutive box tokens
    are all padded. The previous version skipped one extra character, which left
    every second adjacent box unpadded.
    """
    padding = pad_token * pad_num
    i = 0
    while i < len(caption):
        for token in (_PREBOX_TOKEN, _BOX_TOKEN):
            if caption[i: i + len(token)] == token:
                caption = caption[:i] + padding + caption[i:]
                i += len(padding) + len(token)
                break
        else:
            i += 1
    return caption


def preprocess_ground_caption(sample, image_processor, tokenizer, image_embedding_size, generator, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None, args=None):
    """Turn an ``(image, caption, logits, boxes, relation_data)`` sample into grounded model inputs.

    With probability ``prob_ground`` the caption is grounded with ``generator``
    and ``add_loc_to_text``. Visual tokens are then expanded with box tokens,
    optionally rewritten to format v2 and padded for ROI-align, and the result
    is prefixed with the image placeholder tokens and tokenized.

    :returns: ``(image, input_ids, attention_mask, added_bbox, relations)``;
        ``relations`` is currently always empty (``relation_data`` is unused).
    :raises NotImplementedError: for single-box (Visual-Genome-style) samples
    :raises ValueError: if ``args.only_grounded_sample`` and nothing was grounded
    """
    assert max_length is not None
    assert not single, "single is not supported for preprocess_ground_caption"
    image, caption, logits_filt, boxes_filt, relation_data = sample
    if (
        len(logits_filt.shape) == 1 and logits_filt.shape[0] == 4
        and len(boxes_filt.shape) == 1 and boxes_filt.shape[0] == 4
    ):
        # lack relation data; preprocess_visual_genome would handle the
        # 4-tuple form, but it does not accept this 5-tuple sample yet.
        raise NotImplementedError
    image = preprocess_image(image, image_processor=image_processor)
    added_bbox = []
    if (prob_ground != 0 and random.random() <= prob_ground) or prob_ground == 1.0:
        boxes_filt, pred_phrases = generator.postprocess(logits_filt, boxes_filt, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
        caption, added_bbox = add_loc_to_text(
            boxes_filt, pred_phrases, caption,
            expand=args.expand, always_expand=args.longer_previsual,
        )
    caption, added_bbox = _expand_visual_tokens(caption, added_bbox)
    if use_format_v2:
        caption, added_bbox = _apply_format_v2(caption, added_bbox, args)
    # The (pre)endofattr markers only delimit segments while rewriting.
    caption = caption.replace(_PREEND_OF_ATTR_TOKEN, _OBJECT_TOKEN).replace(_END_OF_ATTR_TOKEN, _END_TOKEN)
    if args.roi_align:
        caption = _insert_roi_padding(caption, tokenizer.pad_token, args.roi_output_size ** 2 - 1)
    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
    relations = []
    if args.only_grounded_sample and "<|#visual#|>" not in caption:
        raise ValueError
    return image, input_ids, attention_mask, added_bbox, relations


def preprocess_visual_genome(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
    """Build inputs for an ``(image, caption, xyxy, _)`` sample grounded by a single box.

    The box is divided by 224 (presumably the model input resolution -- confirm).
    Returns ``(image, input_ids, attention_mask, added_bbox)``.
    """
    assert max_length is not None
    assert not single, "single is not supported for preprocess_ground_caption"
    image, caption, xyxy, _ = sample
    image = preprocess_image(image, image_processor=image_processor)
    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|><|#object#|>" + caption.strip() + "<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
    added_bbox = [torch.tensor(np.expand_dims(xyxy, 0).astype(np.float32) / 224)]
    return image, input_ids, attention_mask, added_bbox


special_predicate = [
    "and",
    "has",
    "says",
    "wears",
]

original_predicate = {
    "and": "and",
    "has": "have",
    "says": "say",
    "wears": "wear",
}

# Order of the boxes attached to each VG_RELATION_TEMPLATES entry
# ("A" = boxA, "B" = boxB).
_VG_TEMPLATE_BOX_ORDER = {0: "AB", 1: "ABAB", 2: "AAB", 3: "BAB", 4: "AB", 5: "BA"}


def generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation):
    """Render a random VG relation template and the matching list of ``(1, 4)`` box tensors.

    Relations ``and``/``of`` always use template 0. Predicates in
    ``special_predicate`` are phrased without "is".

    :raises NotImplementedError: for a template id without a known box order
    """
    if relation in ["and", "of"]:
        template_id = 0
    else:
        template_id = random.choice(range(len(VG_RELATION_TEMPLATES)))
    is_special = relation in special_predicate
    text = VG_RELATION_TEMPLATES[template_id].format(
        nameA=nameA,
        nameB=nameB,
        relation=relation,
        use_is="" if is_special else "is",
        is_or_does="does" if is_special else "is",
        relation_do=original_predicate[relation] if is_special else relation,
    )
    if template_id not in _VG_TEMPLATE_BOX_ORDER:
        raise NotImplementedError
    boxes = {"A": boxA, "B": boxB}
    added_bbox = [torch.tensor([boxes[key]]) for key in _VG_TEMPLATE_BOX_ORDER[template_id]]
    return text, added_bbox


def generate_pisc_sample(boxA, boxB, relation):
    """Render a random PISC template; the two boxes are attached in random order.

    Template 0 gets two ``(1, 4)`` tensors, template 1 one ``(2, 4)`` tensor.

    :raises NotImplementedError: for other template ids (previously an UnboundLocalError)
    """
    template_id = random.choice(range(len(PISC_TEMPLATES)))
    text = PISC_TEMPLATES[template_id].format(relation=relation)
    if template_id == 0:
        first, second = (boxA, boxB) if random.random() < 0.5 else (boxB, boxA)
        added_bbox = [torch.tensor([first]), torch.tensor([second])]
    elif template_id == 1:
        first, second = (boxA, boxB) if random.random() < 0.5 else (boxB, boxA)
        added_bbox = [torch.tensor([first, second])]
    else:
        raise NotImplementedError
    return text, added_bbox


def preprocess_instruct(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
    """Build a padded instruction sample from ``(image_path, dataset, data)``.

    The image is loaded from disk, resized to ``image_processor.transforms[0].size``
    and processed. The text comes from the PISC or VG relation generators.
    Returns ``(images, 1, [2], input_ids, attention_mask, added_bbox)``.

    :raises NotImplementedError: for an unknown ``dataset`` name
    """
    image_path, dataset, data = sample
    # Close the file handle once the resized copy has been made.
    with Image.open(image_path) as raw_image:
        size = image_processor.transforms[0].size
        image = raw_image.resize((size, size))
    if dataset == "pisc_relation_split":
        boxA = data[0]
        boxB = data[1]
        relation = data[2]
        text, added_bbox = generate_pisc_sample(boxA, boxB, relation)
    elif dataset == "vg_relation":
        boxA = data[0][0]
        nameA = data[0][1]
        boxB = data[1][0]
        nameB = data[1][1]
        relation = data[2]
        text, added_bbox = generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation)
    else:
        raise NotImplementedError(f"Unsupported instruct dataset: {dataset}")
    image = preprocess_image(image, image_processor=image_processor)
    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + text + tokenizer.eos_token
    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=True)
    images = image.unsqueeze(0)
    image_start_index_list = [2]
    return images, len(images), image_start_index_list, input_ids, attention_mask, added_bbox


def preprocess_caption(sample, image_processor, tokenizer, image_embedding_size, max_length, single=False):
    """Prefix an ``(image, caption)`` pair with image placeholder tokens.

    Returns ``(image, input_ids, attention_mask)``.
    """
    image, caption = sample
    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
    image = preprocess_image(image, image_processor=image_processor)
    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=single)
    return image, input_ids, attention_mask


def get_pile_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """Build an infinite, packed text-only loader over ``args.pile_shards``."""
    input_shards = args.pile_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)
    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    preprocess_text_fn = functools.partial(preprocess_encoded_text, tokenizer=tokenizer, max_length=args.max_length)
    pipeline = [
        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
        tarfile_to_samples_nothrow,
        wds.shuffle(
            bufsize=_SAMPLE_SHUFFLE_SIZE,
            initial=_SAMPLE_SHUFFLE_INITIAL,
        ),
        wds.to_tuple("txt", handler=log_and_continue),
        wds.map_tuple(
            preprocess_text_fn, handler=log_and_continue
        ),
    ]
    # with_epoch(sys.maxsize) will give us an infinite sample stream
    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
    delimiter_id = _token_id(tokenizer, tokenizer.eos_token)
    dataset = ConcatDataset(iter(dataset), max_length=args.max_length, delimiter_id=delimiter_id)

    def text_collate_fn(items):
        """Stack packed sequences. Returns ``(None, None)`` if they cannot be stacked."""
        try:
            input_ids = torch.cat([x[0].unsqueeze(0) for x in items], dim=0)
            attention_mask = torch.cat([x[1].unsqueeze(0) for x in items], dim=0)
            return input_ids, attention_mask
        except Exception:
            return None, None

    dataloader = wds.WebLoader(
        dataset,
        batch_size=args.batch_size_pile,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=False,
        collate_fn=text_collate_fn,
    )
    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)


# FIXME:
# modify /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/webdataset/filters.py, line 433
# combine_tensors=True to combine_tensors=False
def get_ground_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """Build the infinite grounded image-text loader over ``args.laion_shards``.

    Captions are grounded with a CPU-only GroundingDINO ``caption_grounder`` via
    :func:`preprocess_ground_caption`, then packed by :class:`ConcatDataset`.
    """
    input_shards = args.laion_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)
    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    generator = caption_grounder(
        config_file=GROUNDING_DINO_CONFIG,
        checkpoint_path=GROUNDING_DINO_CHECKPOINT,
        cpu_only=True,
        text_threshold=0.3,
    )
    preprocess_ground_caption_fn = functools.partial(
        preprocess_ground_caption,
        image_processor=image_processor,
        tokenizer=tokenizer,
        image_embedding_size=args.vis_embed_size,
        single=args.single,
        generator=generator,
        prob_ground=args.prob_ground,
        use_format_v2=args.use_format_v2,
        add_visual_token=args.add_visual_token,
        max_length=args.max_length,
        args=args,
    )
    pipeline = [
        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
        tarfile_to_samples_nothrow,
        wds.shuffle(
            bufsize=_SAMPLE_SHUFFLE_SIZE,
            initial=_SAMPLE_SHUFFLE_INITIAL,
        ),
        wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", partial=True, handler=log_and_continue),
        wds.to_tuple("jpg;png;jpeg", "txt", "logits.pyd", "boxes.pyd", "relation_data", handler=log_and_continue),
        wds.map(
            preprocess_ground_caption_fn, handler=log_and_continue
        ),
    ]
    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
    dataset = ConcatDataset(
        iter(dataset),
        max_length=args.max_length,
        delimiter_id=_token_id(tokenizer, tokenizer.eos_token),
        pad_id=tokenizer.pad_token_id,
        media_id=_token_id(tokenizer, "<|#image#|>"),
        endofmedia_id=_token_id(tokenizer, "<|#endofimage#|>"),
        box_id=_token_id(tokenizer, "<|#box#|>"),
        visual_id=_token_id(tokenizer, "<|#visual#|>"),
        image_embedding_size=args.vis_embed_size,
        single=args.single,
    )

    def image_collate_fn(items):
        """Batch packed samples; boxes are flattened into one list across the batch."""
        images = torch.cat([x[0] for x in items], dim=0)
        image_nums = [x[1] for x in items]
        image_start_index_list = [x[2] for x in items]
        input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
        attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
        expand_list = [box for x in items for box in x[5]]
        relations_list = [x[6] for x in items]
        return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list, relations_list

    dataloader = wds.WebLoader(
        dataset,
        batch_size=args.batch_size_laion,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=False,
        collate_fn=image_collate_fn,
    )
    dataloader.num_batches = _laion_num_batches(args, floor)
    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)


def get_image_text_pair_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """Build the infinite, packed plain image-caption loader over ``args.laion_shards``."""
    input_shards = args.laion_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)
    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    preprocess_caption_fn = functools.partial(
        preprocess_caption,
        image_processor=image_processor,
        tokenizer=tokenizer,
        image_embedding_size=args.vis_embed_size,
        single=args.single,
        max_length=args.max_length,
    )
    pipeline = [
        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
        tarfile_to_samples_nothrow,
        wds.shuffle(
            bufsize=_SAMPLE_SHUFFLE_SIZE,
            initial=_SAMPLE_SHUFFLE_INITIAL,
        ),
        wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
        wds.map(
            preprocess_caption_fn, handler=log_and_continue
        ),
    ]
    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
    dataset = ConcatDataset(
        iter(dataset),
        max_length=args.max_length,
        delimiter_id=_token_id(tokenizer, tokenizer.eos_token),
        pad_id=tokenizer.pad_token_id,
        media_id=_token_id(tokenizer, "<|#image#|>"),
        endofmedia_id=_token_id(tokenizer, "<|#endofimage#|>"),
        image_embedding_size=args.vis_embed_size,
        single=args.single,
    )

    def image_collate_fn(items):
        """Batch packed image-text samples."""
        images = torch.cat([x[0] for x in items], dim=0)
        image_nums = [x[1] for x in items]
        image_start_index_list = [x[2] for x in items]
        input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
        attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
        return images, image_nums, image_start_index_list, input_ids, attention_mask

    dataloader = wds.WebLoader(
        dataset,
        batch_size=args.batch_size_laion,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=False,
        collate_fn=image_collate_fn,
    )
    dataloader.num_batches = _laion_num_batches(args, floor)
    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)


def get_instruct_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """Build the infinite instruction-tuning loader (one padded sample per row)."""
    input_shards = args.laion_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)
    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    preprocess_instruct_fn = functools.partial(
        preprocess_instruct,
        image_processor=image_processor,
        tokenizer=tokenizer,
        image_embedding_size=args.vis_embed_size,
        max_length=args.max_length,
    )
    pipeline = [
        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
        tarfile_to_samples_nothrow,
        wds.shuffle(
            bufsize=_SAMPLE_SHUFFLE_SIZE,
            initial=_SAMPLE_SHUFFLE_INITIAL,
        ),
        wds.decode(partial=True),
        wds.to_tuple("image_path.txt", "dataset.txt", "data.pyd", handler=log_and_continue),
        wds.map(
            preprocess_instruct_fn, handler=log_and_continue
        ),
    ]
    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)

    def image_collate_fn(items):
        """Batch instruction samples; boxes are flattened into one list across the batch."""
        images = torch.cat([x[0] for x in items], dim=0)
        image_nums = [x[1] for x in items]
        image_start_index_list = [x[2] for x in items]
        input_ids = torch.cat([x[3] for x in items], dim=0)
        attention_mask = torch.cat([x[4] for x in items], dim=0)
        expand_list = [box for x in items for box in x[5]]
        return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list

    dataloader = wds.WebLoader(
        dataset,
        batch_size=args.batch_size_laion,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=False,
        collate_fn=image_collate_fn,
    )
    dataloader.num_batches = _laion_num_batches(args, floor)
    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)


def get_dataset_fn(dataset_type):
    """Return the dataset builder for ``dataset_type``.

    :raises NotImplementedError: for ``mmc4`` and ``vqav2``
    :raises ValueError: for unknown types
    """
    if dataset_type in ("mmc4", "vqav2"):
        raise NotImplementedError
    builders = {
        "pile": get_pile_dataset,
        "ground_image_text": get_ground_laion_dataset,
        "image_text": get_image_text_pair_dataset,
        "instruct": get_instruct_dataset,
    }
    if dataset_type not in builders:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")
    return builders[dataset_type]


def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
    """Build the :class:`DataInfo` for ``dataset_type``."""
    return get_dataset_fn(dataset_type)(
        args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
    )