from __future__ import annotations
import ast
import copy
import json
import logging
import math
import os
import random
import sys
import time
import io
import itertools
import functools
import braceexpand
from dataclasses import dataclass
from multiprocessing import Value
import pyarrow as pa
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
import torchvision.transforms.functional as TF
import torch.distributed as dist
import webdataset as wds
from PIL import Image
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset, DataLoader, Sampler, SubsetRandomSampler, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from webdataset.filters import _shuffle
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
from open_clip import transform

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

try:
    from petrel_client.client import Client
except ImportError:
    # petrel_client is optional; TCSLoader raises if it is used without it.
    Client = None

def pil_loader(img_str):
    buff = io.BytesIO(img_str)
    return Image.open(buff).convert("RGB")

@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD

def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: list of data gathered from each rank
    """
    if dist.get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()  # use CPU group by default, to reduce GPU RAM usage.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]
    output = [None for _ in range(world_size)]
    dist.all_gather_object(output, data, group=group)
    return output

def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.

    If workers need a shared RNG, they can use this shared seed to
    create one.
    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2**31)
    all_ints = all_gather(ints)
    return all_ints[0]
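
# Usage sketch (an illustration, not called anywhere in this file): after
# torch.distributed is initialized, every rank calls shared_random_seed() and
# seeds a local RNG with the common value so shuffles stay in lockstep.
def _example_shared_rng():
    seed = shared_random_seed()  # collective: every rank must call it
    return random.Random(seed)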

class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The sampler in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """
    def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True, seed=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) - 1
        self.total_size = len(dataset)
        self.shuffle = shuffle
        if seed is None:
            seed = shared_random_seed()
        self.seed = int(seed)

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        start = self.rank
        yield from itertools.islice(self._infinite_indices(), start, None, self.num_replicas)

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.total_size, generator=g).tolist()
            else:
                yield from torch.arange(self.total_size).tolist()
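
# Usage sketch (illustrative; seed passed explicitly so no dist init is needed):
# because the index stream is infinite, training loops by step count and the
# DataLoader built this way never exhausts.
def _example_infinite_loader(dataset, batch_size=32):
    sampler = TrainingSampler(dataset, num_replicas=1, rank=0, seed=0)
    return iter(DataLoader(dataset, batch_size=batch_size, sampler=sampler))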

class TCSLoader(object):
    def __init__(self, time_limit=3):
        if Client is None:
            raise ImportError('petrel_client is required to read images from s3:// paths')
        conf_path = os.environ.get('CEPH_CONFIG', './petreloss.config')
        self.client = Client(conf_path)
        self.time_limit = time_limit

    def __call__(self, fn):
        try:
            img_value_str = self.client.get(fn)
            img = pil_loader(img_value_str)
            return img
        except Exception as e:
            print('Read image failed ({})'.format(fn))
            raise e

class CsvDataset(Dataset):
    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None):
        logging.debug(f'Loading csv data from {input_filename}.')
        df = pd.read_csv(input_filename, sep=sep)
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        logging.debug('Done loading data.')
        self.tokenize = tokenizer

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        images = self.transforms(Image.open(str(self.images[idx])))
        texts = self.tokenize([str(self.captions[idx])])[0]
        return images, texts

class SharedEpoch:
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value
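
# Why a multiprocessing.Value: dataloader workers run in forked processes, so a
# plain int updated in the main process would never be seen by them; the shared
# 'i' slot makes set_epoch() below visible to every worker.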

@dataclass
class DataInfo:
    dataloader: DataLoader
    data_type: str
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)

def expand_urls(urls, weights=None):
    if weights is None:
        expanded_urls = wds.shardlists.expand_urls(urls)
        return expanded_urls, None
    if isinstance(urls, str):
        urllist = urls.split("::")
        weights = weights.split('::')
        assert len(weights) == len(urllist), \
            f"Expected the number of data components ({len(urllist)}) and weights ({len(weights)}) to match."
        weights = [float(weight) for weight in weights]
        all_urls, all_weights = [], []
        for url, weight in zip(urllist, weights):
            expanded_url = list(braceexpand.braceexpand(url))
            expanded_weights = [weight for _ in expanded_url]
            all_urls.extend(expanded_url)
            all_weights.extend(expanded_weights)
        return all_urls, all_weights
    else:
        all_urls = list(urls)
        return all_urls, weights
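
# Illustration (hypothetical shard names): "::" separates data components and
# brace notation expands within each one; every expanded shard inherits the
# weight of its component.
#   expand_urls("a-{00..01}.tar::b.tar", weights="0.7::0.3")
#   -> (["a-00.tar", "a-01.tar", "b.tar"], [0.7, 0.7, 0.3])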

def get_dataset_size(shards):
    shards_list, _ = expand_urls(shards)
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, 'sizes.json')
    len_filename = os.path.join(dir_path, '__len__')
    if os.path.exists(sizes_filename):
        sizes = json.load(open(sizes_filename, 'r'))
        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
    elif os.path.exists(len_filename):
        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
        total_size = ast.literal_eval(open(len_filename, 'r').read())
    else:
        total_size = None  # num samples undefined
        # some common dataset sizes (at time of authors last download)
        # CC3M (train): 2905954
        # CC12M: 10968539
        # LAION-400M: 407332084
        # LAION-2B (english): 2170337258
    num_shards = len(shards_list)
    return total_size, num_shards

def get_imagenet(args, preprocess_fns, split):
    assert split in ["train", "val", "v2"]
    is_train = split == "train"
    preprocess_train, preprocess_val = preprocess_fns
    if split == "v2":
        from imagenetv2_pytorch import ImageNetV2Dataset
        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
    else:
        if is_train:
            data_path = args.imagenet_train
            preprocess_fn = preprocess_train
        else:
            data_path = args.imagenet_val
            preprocess_fn = preprocess_val
        assert data_path
        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
    if is_train:
        # subsample k random images per class so train-split eval stays cheap
        idxs = np.zeros(len(dataset.targets))
        target_array = np.array(dataset.targets)
        k = 50
        for c in range(1000):
            m = target_array == c
            n = len(idxs[m])
            arr = np.zeros(n)
            arr[:k] = 1
            np.random.shuffle(arr)
            idxs[m] = arr
        idxs = idxs.astype('int')
        sampler = SubsetRandomSampler(np.where(idxs)[0])
    else:
        sampler = None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=sampler,
    )
    return DataInfo(dataloader=dataloader, sampler=sampler, data_type='classification')

def count_samples(dataloader):
    os.environ["WDS_EPOCH"] = "0"
    n_elements, n_batches = 0, 0
    for images, texts in dataloader:
        n_batches += 1
        n_elements += len(images)
        assert len(images) == len(texts)
    return n_elements, n_batches

def filter_no_caption_or_no_image(sample):
    has_caption = ('txt' in sample)
    has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
    return has_caption and has_image

def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True

def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample

def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples

def pytorch_worker_seed(increment=0):
    """get dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour using the seed already created for pytorch dataloader workers if it exists
        seed = worker_info.seed
        if increment:
            # space out seed increments so they can't overlap across workers in different iterations
            seed += increment * max(1, worker_info.num_workers)
        return seed
    # fallback to wds rank based seed
    return wds.utils.pytorch_worker_seed()

_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000

class detshuffle2(wds.PipelineStage):
    def __init__(
            self,
            bufsize=1000,
            initial=100,
            seed=0,
            epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            # If seed is negative, we use the worker's seed; this will be different across all nodes/workers
            seed = pytorch_worker_seed(epoch)
        else:
            # This seed is deterministic AND the same across all nodes/workers in each epoch
            seed = self.seed + epoch
        rng.seed(seed)
        return _shuffle(src, self.bufsize, self.initial, rng)
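
# Seeding sketch (illustrative only): a non-negative seed plus the epoch gives
# every worker the same shard order for that epoch, while a negative seed falls
# back to independent per-worker seeds.
def _example_detshuffle_seed(stage_seed, epoch):
    if stage_seed < 0:
        return pytorch_worker_seed(epoch)  # differs across workers
    return stage_seed + epoch  # identical across workers for a given epoch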

class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""
    def __init__(
            self,
            urls,
            weights=None,
            nshards=sys.maxsize,
            worker_seed=None,
            deterministic=False,
            epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls, weights = expand_urls(urls, weights)
        self.urls = urls
        self.weights = weights
        if self.weights is not None:
            assert len(self.urls) == len(self.weights), \
                f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic
            if self.worker_seed is None:
                # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
                seed = pytorch_worker_seed(epoch)
            else:
                seed = self.worker_seed() + epoch
            self.rng.seed(seed)
        for _ in range(self.nshards):
            if self.weights is None:
                yield dict(url=self.rng.choice(self.urls))
            else:
                yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])
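
# Usage sketch (hypothetical shard names): weighted resampling draws shards
# with replacement in proportion to the upsampling factors.
def _example_resampled_shards():
    shards = ResampledShards2("web-{00..09}.tar::cc-{00..04}.tar",
                              weights="1.0::2.0", nshards=10)
    return [s["url"] for s in shards]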

def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None
    resampled = getattr(args, 'dataset_resampled', False) and is_train
    num_samples, num_shards = get_dataset_size(input_shards)
    if not num_samples:
        if is_train:
            num_samples = args.train_num_samples
            if not num_samples:
                raise RuntimeError(
                    'Currently, the number of dataset samples must be specified for the training dataset. '
                    'Please specify it via `--train-num-samples` if no dataset length info is present.')
        else:
            num_samples = args.val_num_samples or 0  # eval will just exhaust the iterator if not specified
    shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
    if resampled:
        pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)]
    else:
        assert args.train_data_upsampling_factors is None, \
            "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)."
        pipeline = [wds.SimpleShardList(input_shards)]
    # at this point we have an iterator over all the shards
    if is_train:
        if not resampled:
            pipeline.extend([
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ])
        pipeline.extend([
            # at this point, we have an iterator over the shards assigned to each worker at each node
            tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ])
    else:
        pipeline.extend([
            wds.split_by_worker,
            # at this point, we have an iterator over the shards assigned to each worker
            wds.tarfile_to_samples(handler=log_and_continue),
        ])
    pipeline.extend([
        wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.rename(image="jpg;png;jpeg;webp", text="txt"),
        wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
        wds.to_tuple("image", "text"),
        wds.batched(args.batch_size, partial=not is_train)
    ])
    dataset = wds.DataPipeline(*pipeline)
    if is_train:
        if not resampled:
            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
        # roll over and repeat a few samples to get same number of full batches on each node
        round_fn = math.floor if floor else math.ceil
        global_batch_size = args.batch_size * args.world_size
        num_batches = round_fn(num_samples / global_batch_size)
        num_workers = max(1, args.workers)
        num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size
        dataset = dataset.with_epoch(num_worker_batches)  # each worker is iterating over this
    else:
        # last batches are partial, eval is done on single (master) node
        num_batches = math.ceil(num_samples / args.batch_size)
    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=True,
    )
    # FIXME not clear which approach is better, with_epoch before vs after dataloader?
    # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
    # if is_train:
    #     # roll over and repeat a few samples to get same number of full batches on each node
    #     global_batch_size = args.batch_size * args.world_size
    #     num_batches = math.ceil(num_samples / global_batch_size)
    #     num_workers = max(1, args.workers)
    #     num_batches = math.ceil(num_batches / num_workers) * num_workers
    #     num_samples = num_batches * global_batch_size
    #     dataloader = dataloader.with_epoch(num_batches)
    # else:
    #     # last batches are partial, eval is done on single (master) node
    #     num_batches = math.ceil(num_samples / args.batch_size)
    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples
    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch, data_type='image-text')
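
# Usage sketch (hedged; the args fields below mirror exactly what
# get_wds_dataset reads, with hypothetical values):
def _example_wds_datainfo(preprocess_img, tokenizer):
    from types import SimpleNamespace
    args = SimpleNamespace(
        train_data="shards-{0000..0099}.tar", val_data=None,
        dataset_resampled=False, train_data_upsampling_factors=None,
        train_num_samples=100000, val_num_samples=0,
        batch_size=64, workers=4, world_size=1, seed=0)
    return get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)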

def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    dataset = CsvDataset(
        input_filename,
        preprocess_fn,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep=args.csv_separator,
        tokenizer=tokenizer
    )
    num_samples = len(dataset)
    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and not args.distributed and sampler is None
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)
    return DataInfo(dataloader=dataloader, sampler=sampler, data_type='image-text')

class SyntheticDataset(Dataset):
    def __init__(self, transform=None, image_size=(224, 224), caption="Dummy caption", dataset_size=100, tokenizer=None):
        self.transform = transform
        self.image_size = image_size
        self.caption = caption
        self.image = Image.new('RGB', image_size)
        self.dataset_size = dataset_size
        self.preprocess_txt = lambda text: tokenizer(text)[0]

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        # fall back to the raw image when no transform is given (the original
        # returned None in that case)
        image = self.transform(self.image) if self.transform is not None else self.image
        return image, self.preprocess_txt(self.caption)
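
# Benchmark sketch: every item is the same blank image and caption, which
# isolates preprocessing and dataloader overhead from disk I/O.
def _example_synthetic_loader(preprocess, tokenizer, batch_size=4):
    ds = SyntheticDataset(transform=preprocess, dataset_size=8, tokenizer=tokenizer)
    return DataLoader(ds, batch_size=batch_size)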

def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    image_size = preprocess_fn.transforms[0].size
    dataset = SyntheticDataset(
        transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer)
    num_samples = len(dataset)
    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and not args.distributed and sampler is None
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)
    return DataInfo(dataloader=dataloader, sampler=sampler, data_type='image-text')

class PreferenceDataset(Dataset):
    def __init__(self, meta_file, image_folder, transforms, tokenizer, extra_data=(None, None)):
        extra_meta, extra_folder = extra_data
        self.transforms = transforms
        self.tokenizer = tokenizer
        self.open_image = Image.open
        # image_folder may be None when only extra data is used
        if image_folder is not None and image_folder.startswith('s3://'):
            loader = TCSLoader()
            self.open_image = loader
        if meta_file is not None:
            with open(meta_file, 'r') as f:
                self.table = pa.Table.from_pylist(json.load(f))
            self.image_folder = image_folder
        else:
            self.table = []
        if extra_meta:
            with open(extra_meta, 'r') as f:
                meta = json.load(f)
            self.files = [t['files'] for t in meta]
            self.extra_captions = [t['caption'] for t in meta]
            self.extra_label = [t['human_preference'] for t in meta]
            self.extra_image_folder = extra_folder
        else:
            self.extra_captions = []

    def __len__(self):
        return len(self.table) + len(self.extra_captions)

    def __getitem__(self, idx):
        try:
            if idx < len(self.table):
                images = [self.transforms(self.open_image(os.path.join(self.image_folder, file_name)))
                          for file_name in self.table.column('file_path')[idx].as_py()]
                label = self.table.column('pap_pref')[idx].as_py()
                caption = self.tokenizer(self.table.column('prompt')[idx].as_py())
            else:
                idx = idx - len(self.table)
                images = [self.transforms(self.open_image(os.path.join(self.extra_image_folder, f)))
                          for f in self.files[idx]]
                label = self.extra_label[idx]
                caption = self.tokenizer(self.extra_captions[idx])
            if len({img.size() for img in images}) != 1:
                # all images of one prompt must share a shape; otherwise skip ahead
                return self.__getitem__((idx + 1) % len(self))
            return images, label, caption
        except Exception:
            return self.__getitem__((idx + 1) % len(self))

class HPDDataset(Dataset):
    def __init__(self, meta_file, image_folder, transforms, tokenizer, is_train=True):
        self.transforms = transforms
        self.tokenizer = tokenizer
        self.open_image = Image.open
        self.is_train = is_train
        if image_folder.startswith('s3://'):
            loader = TCSLoader()
            self.open_image = loader
        if meta_file is not None:
            with open(meta_file, 'r') as f:
                self.table = pa.Table.from_pylist(json.load(f))
            self.image_folder = image_folder
        else:
            self.table = []

    def __len__(self):
        return len(self.table)

    def __getitem__(self, idx):
        # train and val share the same layout: the generated images for a
        # prompt, the human preference labels, and the tokenized prompt
        try:
            images = [self.transforms(self.open_image(os.path.join(self.image_folder, file_name)))
                      for file_name in self.table.column('file_path')[idx].as_py()]
            if len({img.size() for img in images}) != 1:
                return self.__getitem__((idx + 1) % len(self))
            label = self.table.column('human_preference')[idx].as_py()
            caption = self.tokenizer(self.table.column('prompt')[idx].as_py())
            return images, label, caption
        except Exception:
            return self.__getitem__((idx + 1) % len(self))

class RatingDataset(Dataset):
    def __init__(self, meta_file, image_folder, transforms):
        self.transforms = transforms
        self.image_folder = image_folder
        self.open_image = Image.open
        self.max_size = 224
        if image_folder.startswith('s3://'):
            loader = TCSLoader()
            self.open_image = loader
        with open(meta_file, 'r') as f:
            self.table = pa.Table.from_pylist(json.load(f))

    def __len__(self):
        return len(self.table)

    def __getitem__(self, idx):
        try:
            images = self.transforms(self.open_image(os.path.join(self.image_folder, self.table.column('path')[idx].as_py())))
            img_width, img_height = images.shape[1:]
            if img_width != self.max_size or img_height != self.max_size:
                return self.__getitem__((idx + 10) % len(self))
            label = self.table.column('rating')[idx].as_py()
            return images, label
        except Exception:
            return self.__getitem__((idx + 1) % len(self))

class RankingDataset(Dataset):
    def __init__(self, meta_file, image_folder, transforms, tokenizer):
        self.transforms = transforms
        self.image_folder = image_folder
        self.open_image = Image.open
        if image_folder.startswith('s3://'):
            loader = TCSLoader()
            self.open_image = loader
        self.tokenizer = tokenizer
        with open(meta_file, 'r') as f:
            self.table = pa.Table.from_pylist(json.load(f))

    def __len__(self):
        return len(self.table)

    def __getitem__(self, idx):
        try:
            images = [self.transforms(self.open_image(os.path.join(self.image_folder, file_name)))
                      for file_name in self.table.column('image_path')[idx].as_py()]
            label = self.table.column('rank')[idx].as_py()
            caption = self.tokenizer(self.table.column('prompt')[idx].as_py())
            return images, label, caption
        except Exception:
            return self.__getitem__((idx + 1) % len(self))

class RegionDataset(Dataset):
    def __init__(self, meta_file, image_folder, transforms):
        self.transforms = transforms
        self.image_folder = image_folder
        self.open_image = Image.open
        with open(meta_file, 'r') as f:
            self.table = pa.Table.from_pylist(json.load(f))

    def __len__(self):
        return len(self.table)

    def __getitem__(self, idx):
        try:
            img = self.open_image(os.path.join(self.image_folder, self.table.column('image_path')[idx].as_py()))
            mask = self.open_image(os.path.join(self.image_folder, self.table.column('mask_path')[idx].as_py()))
            # attach the mask as an alpha channel so image and mask get the same transforms
            img.putalpha(mask)
            masked_image = self.transforms(img)
            image = masked_image[:3]
            mask = masked_image[3]
            return image, mask
        except Exception:
            return self.__getitem__((idx + 1) % len(self))

class ImageRewardDataset(Dataset):
    def __init__(self, meta_file, image_folder, transforms, tokenizer):
        self.transforms = transforms
        self.image_folder = image_folder
        self.open_image = Image.open
        self.tokenizer = tokenizer
        with open(meta_file, 'r') as f:
            self.table = pa.Table.from_pylist(json.load(f))

    def __len__(self):
        return len(self.table)

    def __getitem__(self, idx):
        images = [self.transforms(self.open_image(os.path.join(self.image_folder, file_name)))
                  for file_name in self.table.column('generations')[idx].as_py()]
        label = self.table.column('ranking')[idx].as_py()
        caption = self.tokenizer(self.table.column('prompt')[idx].as_py())
        return images, label, caption

def set_env_vars(worker_id):
    # worker_init_fn: clear proxy settings inside dataloader workers so that
    # cluster-internal storage is reached directly
    os.environ['http_proxy'] = ''
    os.environ['https_proxy'] = ''

def collate_rating(batch):
    images = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch])
    images = torch.stack(images)
    return images, labels

def get_rating_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    # only training data
    assert is_train
    dataset = RatingDataset(meta_file=args.train_data,
                            image_folder=args.train_folder,
                            transforms=preprocess_fn)
    num_samples = len(dataset)
    sampler = TrainingSampler(dataset) if args.distributed else None
    shuffle = is_train and not args.distributed
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
        collate_fn=collate_rating,
        worker_init_fn=set_env_vars,
        persistent_workers=True,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)
    return DataInfo(dataloader=dataloader, sampler=sampler, data_type='rating')

def collate_pref(batch):
    # each sample holds a variable number of images for one prompt; flatten them
    # into a single batch and keep per-sample counts so they can be regrouped
    images = [torch.stack(sample[0]) for sample in batch]
    num_images = torch.tensor([g.size(0) for g in images])
    labels = torch.tensor([sample[1] for sample in batch])
    captions = torch.cat([sample[2] for sample in batch])
    images = torch.cat(images)
    return images, num_images, labels, captions
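
# Regrouping sketch (illustrative): downstream code can split the flattened
# batch back into per-prompt image stacks.
def _example_regroup(images, num_images):
    return torch.split(images, num_images.tolist())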

def get_preference_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None, extra_val=False):
    if is_train:
        extra_data = (args.extra_train_data, args.extra_train_folder)
        dataset = PreferenceDataset(meta_file=args.train_data,
                                    image_folder=args.train_folder,
                                    transforms=preprocess_fn, tokenizer=tokenizer, extra_data=extra_data)
    else:
        if extra_val:
            dataset = PreferenceDataset(meta_file=None,
                                        image_folder=None,
                                        transforms=preprocess_fn, tokenizer=tokenizer,
                                        extra_data=(args.extra_val_data, args.extra_val_folder))
        else:
            dataset = PreferenceDataset(meta_file=args.val_data,
                                        image_folder=args.val_folder,
                                        transforms=preprocess_fn, tokenizer=tokenizer)
    num_samples = len(dataset)
    sampler = TrainingSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and not args.distributed and sampler is None
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
        collate_fn=collate_pref,
        worker_init_fn=set_env_vars,
        persistent_workers=True,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)
    return DataInfo(dataloader=dataloader, sampler=sampler, data_type='preference')

def collate_HPD(batch):
    # HPD training samples always hold exactly two images with two labels
    image_1 = torch.stack([sample[0][0] for sample in batch])
    image_2 = torch.stack([sample[0][1] for sample in batch])
    label_1 = torch.tensor([sample[1][0] for sample in batch])
    label_2 = torch.tensor([sample[1][1] for sample in batch])
    labels = torch.cat([label_1, label_2], dim=0)
    captions = torch.cat([sample[2] for sample in batch])
    images = torch.cat([image_1, image_2])
    return images, labels, captions

def get_HPD_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    dataset = HPDDataset(meta_file=args.train_data if is_train else args.val_data,
                         image_folder=args.train_folder if is_train else args.val_folder,
                         transforms=preprocess_fn, tokenizer=tokenizer, is_train=is_train)
    num_samples = len(dataset)
    sampler = TrainingSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and not args.distributed and sampler is None
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
        collate_fn=collate_HPD if is_train else collate_pref,
        worker_init_fn=set_env_vars,
        persistent_workers=True,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)
    return DataInfo(dataloader=dataloader, sampler=sampler, data_type='HPD' if is_train else 'preference')

def get_ranking_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    if is_train:
        dataset = RankingDataset(meta_file=args.train_data,
                                 image_folder=args.train_folder, transforms=preprocess_fn, tokenizer=tokenizer)
    else:
        dataset = RankingDataset(meta_file=args.val_data,
                                 image_folder=args.val_folder, transforms=preprocess_fn, tokenizer=tokenizer)
    num_samples = len(dataset)
    sampler = TrainingSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and not args.distributed and sampler is None
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
        collate_fn=collate_rank,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)
    return DataInfo(dataloader=dataloader, sampler=sampler, data_type='ranking')

def get_regional_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    if is_train:
        dataset = RegionDataset(
            meta_file=args.train_data,
            image_folder=args.train_folder,
            transforms=preprocess_fn
        )
    else:
        dataset = RegionDataset(
            meta_file=args.val_data,
            image_folder=args.val_folder,
            transforms=preprocess_fn
        )
    num_samples = len(dataset)
    sampler = TrainingSampler(dataset) if args.distributed else None
    shuffle = is_train and not args.distributed
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
        worker_init_fn=set_env_vars,
        persistent_workers=True,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)
    return DataInfo(dataloader=dataloader, sampler=sampler, data_type='regional')

def collate_rank(batch):
    images = [torch.stack(sample[0]) for sample in batch]
    num_images = torch.tensor([g.size(0) for g in images])
    labels = [torch.tensor(sample[1]) for sample in batch]
    captions = torch.cat([sample[2] for sample in batch])
    images = torch.cat(images)
    labels = torch.cat(labels)
    return images, num_images, labels, captions

def get_imagereward_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    # only evaluation is supported for ImageReward
    assert not is_train
    dataset = ImageRewardDataset(
        meta_file=args.val_data,
        image_folder=args.val_folder,
        transforms=preprocess_fn,
        tokenizer=tokenizer
    )
    num_samples = len(dataset)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        sampler=None,
        drop_last=False,
        worker_init_fn=set_env_vars,
        collate_fn=collate_rank,
        persistent_workers=True,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)
    return DataInfo(dataloader=dataloader, sampler=None, data_type='ImageReward')

def get_dataset_fn(data_path, dataset_type):
    if dataset_type == "webdataset":
        return get_wds_dataset
    elif dataset_type == "csv":
        return get_csv_dataset
    elif dataset_type == "synthetic":
        return get_synthetic_dataset
    elif dataset_type == "auto":
        ext = data_path.split('.')[-1]
        if ext in ['csv', 'tsv']:
            return get_csv_dataset
        elif ext in ['tar']:
            return get_wds_dataset
        else:
            raise ValueError(
                f"Tried to figure out dataset type, but failed for extension {ext}.")
    elif dataset_type == "preference":
        return get_preference_dataset
    elif dataset_type == "rating":
        return get_rating_dataset
    elif dataset_type == 'ranking':
        return get_ranking_dataset
    elif dataset_type == 'regional':
        return get_regional_dataset
    elif dataset_type == 'ImageReward':
        return get_imagereward_dataset
    elif dataset_type == "HPD":
        return get_HPD_dataset
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
    preprocess_train, preprocess_val = preprocess_fns
    data = {}
    if args.train_data or args.dataset_type == "synthetic":
        assert len(args.train_data) == len(args.dataset_type) == len(args.batch_size) == len(args.workers) \
            == len(args.train_folder) == len(args.train_data_sample_ratio) == len(args.ignore_in_train)
        for train_data, dataset_type, batch_size, workers, train_folder, train_data_sample_ratio, ignore in zip(
                args.train_data, args.dataset_type, args.batch_size, args.workers,
                args.train_folder, args.train_data_sample_ratio, args.ignore_in_train):
            if ignore:
                continue
            if 'train' not in data:
                data['train'] = []
            # build one dataset per component by overriding the per-dataset fields
            new_args = copy.deepcopy(args)
            new_args.train_data = train_data
            new_args.dataset_type = dataset_type
            new_args.batch_size = batch_size
            new_args.workers = workers
            new_args.train_folder = train_folder
            new_args.train_data_sample_ratio = train_data_sample_ratio
            dataset = get_dataset_fn(new_args.train_data, new_args.dataset_type)(
                new_args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
            data['train'].append(dataset)
    if args.val_data[0]:
        assert len(args.val_data) == len(args.dataset_type) == len(args.batch_size) == len(args.workers) \
            == len(args.val_folder) == len(args.ignore_in_val)
        for val_data, dataset_type, batch_size, workers, val_folder, ignore in zip(
                args.val_data, args.dataset_type, args.batch_size, args.workers,
                args.val_folder, args.ignore_in_val):
            if ignore:
                continue
            if 'val' not in data:
                data['val'] = []
            new_args = copy.deepcopy(args)
            new_args.val_data = val_data
            new_args.dataset_type = dataset_type
            new_args.batch_size = batch_size
            new_args.workers = workers
            new_args.val_folder = val_folder
            dataset = get_dataset_fn(new_args.val_data, new_args.dataset_type)(
                new_args, preprocess_val, is_train=False, tokenizer=tokenizer)
            data['val'].append(dataset)
    if args.extra_val_data:
        # the extra validation path is currently disabled
        assert False
        data["extra_val"] = get_dataset_fn(args.val_data, args.dataset_type)(
            args, preprocess_val, is_train=False, tokenizer=tokenizer, extra_val=True)
    if args.imagenet_val is not None:
        data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
    if args.imagenet_v2 is not None:
        data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")
    return data
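
# End-to-end sketch (hedged; field names mirror how get_data reads args in this
# file, with hypothetical list-valued entries, one per dataset component):
def _example_get_data(preprocess_train, preprocess_val, tokenizer):
    from types import SimpleNamespace
    args = SimpleNamespace(
        train_data=["train.json"], val_data=["test.json"],
        dataset_type=["preference"], batch_size=[32], workers=[4],
        train_folder=["images/"], val_folder=["images/"],
        train_data_sample_ratio=[1.0], ignore_in_train=[False], ignore_in_val=[False],
        extra_train_data=None, extra_train_folder=None, extra_val_data=None,
        distributed=False, imagenet_val=None, imagenet_v2=None)
    return get_data(args, (preprocess_train, preprocess_val), tokenizer=tokenizer)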