Spaces:
Runtime error
Runtime error
import argparse | |
import base64 | |
import json | |
import os | |
import tarfile | |
import uuid | |
import zipfile | |
import time | |
import braceexpand | |
import webdataset as wds | |
from tqdm import tqdm | |
from tqdm.contrib.concurrent import process_map | |
# Command-line interface. Every path argument is required: without them the
# script previously crashed deep inside main() with an opaque TypeError.
arg_parser = argparse.ArgumentParser(
    description=(
        "Embed base64-encoded images from image tar shards into the JSONL "
        "documents of the matching doc shards and write the result as "
        "WebDataset tar shards."
    ),
)
arg_parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Directory under which a timestamped run directory is created.",
)
arg_parser.add_argument(
    "--image_shards",
    type=str,
    required=True,
    help="Pass in a list of shards in the format path_to_shard/shard_{0..23098}_images_v2.tar",
)
arg_parser.add_argument(
    "--doc_shards",
    type=str,
    required=True,
    help="Pass in a list of shards in the format path_to_shard/docs_shard_{0..23098}_v2.jsonl.zip",
)
arg_parser.add_argument(
    "--thread",
    type=int,
    default=128,
    help="Number of shard partitions; shard pairs are assigned to partitions round-robin.",
)
# NOTE(review): parsing at import time means importing this module parses
# sys.argv; kept because main() reads this module-level ``args``.
args = arg_parser.parse_args()
def get_txt_to_filename_dict(image_shards, disable_tqdm=False):
    """Map image-name stems to sample keys inside a WebDataset image shard.

    Every sample of ``image_shards`` is read for its ``txt`` and ``json``
    entries. The text before the first ``"."`` of the ``txt`` value becomes the
    dict key, and the sample's ``json["key"]`` the value. ``single_thread``
    appends ``".jpg"`` to the value to locate the image member in the tar.
    When several samples share a stem, the last one wins.

    Args:
        image_shards: Shard path or brace pattern accepted by ``wds.WebDataset``.
        disable_tqdm: Hide the progress bar when True.

    Returns:
        dict mapping image-name stem -> ``json["key"]`` of the sample.
    """
    # Only txt/json are consumed, so don't decode images: ``decode()`` with no
    # image handler still decodes txt and json via webdataset's default
    # handlers, but leaves the jpg payload as raw bytes instead of building
    # (and discarding) a PIL image for every sample.
    dataset = wds.WebDataset(image_shards).decode().to_tuple("txt", "json")
    return {
        txt.split(".")[0]: meta["key"]
        for txt, meta in tqdm(dataset, disable=disable_tqdm)
    }
def _read_image_bytes(image_tar, txt_to_filename_dict, stem):
    """Return the JPEG bytes for image ``stem`` from ``image_tar``, or None.

    ``stem`` is looked up in ``txt_to_filename_dict`` and ``<value>.jpg`` is
    read from the tar. None is returned when the stem is unmapped, the member
    is absent or not a regular file, or the member cannot be read.
    """
    try:
        member = image_tar.extractfile(txt_to_filename_dict[stem] + ".jpg")
        if member is None:  # extractfile returns None for non-regular members
            return None
        with member:
            return member.read()
    except (KeyError, OSError, tarfile.TarError):
        # KeyError covers both an unmapped stem and a name missing from the tar.
        return None


def _embed_images(sample_data, image_tar, txt_to_filename_dict):
    """Set ``image_base64`` on every entry of ``sample_data["image_info"]``.

    Each entry's ``image_name`` is reduced to its stem (text before the first
    ``"."``) and resolved through ``_read_image_bytes``. Found images are stored
    base64-encoded as UTF-8 text; unresolved ones get the string ``"null"``.
    ``sample_data`` is mutated in place.

    Returns:
        list of stems whose image could not be found, in input order.
    """
    missing = []
    for image in sample_data["image_info"]:
        stem = image["image_name"].split(".")[0]
        image_bytes = _read_image_bytes(image_tar, txt_to_filename_dict, stem)
        if image_bytes is None:
            missing.append(stem)
            image["image_base64"] = "null"
        else:
            image["image_base64"] = base64.b64encode(image_bytes).decode("utf-8")
    return missing


def single_thread(args):
    """Convert one partition of (doc shard, image shard) pairs into tar shards.

    Args:
        args: dict with keys ``"i"`` (partition index; only partition 0 prints
            progress), ``"output_dir"`` (where ``%09d.tar`` shards are written,
            at most 1000 samples each), ``"doc_shards"`` (zip files whose first
            member is a JSONL file) and ``"image_shards"`` (image tar shards,
            paired with ``doc_shards`` by position).

    Every JSONL record gets each ``image_info`` entry's image embedded as
    base64 (or ``"null"`` when missing; the missing stem is printed) and is
    written as the ``json`` field of a sample with a random hex ``__key__``.
    """
    i = args["i"]
    output_dir = args["output_dir"]
    doc_shards = args["doc_shards"]
    image_shards = args["image_shards"]
    show_progress = i == 0
    if show_progress:
        tqdm.write(f"output_dir: {output_dir}")
        tqdm.write(f"doc_shards: {doc_shards[:5]}")
        tqdm.write(f"image_shards: {image_shards[:5]}")
    with wds.ShardWriter(os.path.join(output_dir, "%09d.tar"), maxcount=1000) as sink:
        sink.verbose = False
        shard_pairs = zip(doc_shards, image_shards)
        for doc_shard, image_shard in tqdm(
            shard_pairs, disable=not show_progress, total=len(doc_shards)
        ):
            # Previously commented out while still referenced below, which made
            # every lookup a swallowed NameError and every image "null".
            txt_to_filename_dict = get_txt_to_filename_dict(
                image_shard, disable_tqdm=not show_progress
            )
            with tarfile.open(image_shard) as image_tar, zipfile.ZipFile(
                doc_shard, "r"
            ) as zip_file:
                # Assumes the JSON file is the first file in the archive
                json_filename = zip_file.namelist()[0]
                with zip_file.open(json_filename, "r") as json_file:
                    for line in json_file:
                        if not line.strip():  # tolerate blank lines in the JSONL
                            continue
                        sample_data = json.loads(line)
                        for stem in _embed_images(
                            sample_data, image_tar, txt_to_filename_dict
                        ):
                            tqdm.write(stem)
                        sink.write({"__key__": uuid.uuid4().hex, "json": sample_data})
def main():
    """Partition the shard pairs across workers and convert all of them.

    Reads the module-level ``args``. Both shard patterns are brace-expanded and
    paired by position; pair ``k`` goes to partition ``k % args.thread``. Each
    non-empty partition writes into ``<output_dir>/<unix timestamp>/<index>/``.
    A single partition runs in-process; several run via ``process_map``.

    Raises:
        ValueError: If ``--thread`` is not positive, or the two patterns expand
            to different numbers of shards.
    """
    if args.thread < 1:
        raise ValueError(f"--thread must be a positive integer, got {args.thread}")
    doc_shards = list(braceexpand.braceexpand(args.doc_shards))
    image_shards = list(braceexpand.braceexpand(args.image_shards))
    # Explicit check instead of assert, which is stripped under ``python -O``.
    if len(doc_shards) != len(image_shards):
        raise ValueError("Each doc shards must have a corresponding image shard")

    run_dir = os.path.join(args.output_dir, str(int(time.time())))
    os.makedirs(run_dir, exist_ok=True)

    tasks = []
    for i in range(args.thread):
        # Round-robin: partition i gets shard pairs i, i + thread, i + 2*thread, ...
        partition_docs = doc_shards[i::args.thread]
        if not partition_docs:
            # More partitions than shards; skip so no empty output tar is created.
            continue
        thread_dir = os.path.join(run_dir, str(i))
        os.makedirs(thread_dir, exist_ok=True)
        tasks.append({
            "i": i,
            "output_dir": thread_dir,
            "doc_shards": partition_docs,
            "image_shards": image_shards[i::args.thread],
        })

    # Previously only tasks[0] was processed, silently dropping every other
    # partition; run all of them.
    if len(tasks) == 1:
        single_thread(tasks[0])
    elif tasks:
        process_map(single_thread, tasks, max_workers=len(tasks), disable=True)
# Script entry point: run the conversion only when executed directly, not on import.
if __name__ == "__main__":
    raise SystemExit(main())