Spaces:
Running
on
T4
Running
on
T4
import jsonlines | |
from tqdm import tqdm | |
import random | |
import json | |
import os | |
from multiprocessing import Pool | |
from functools import partial | |
import emoji | |
import argparse | |
# Punctuation folded to plain equivalents by ``clean_span``.
# The non-ASCII keys are written as \u escapes on purpose: the literal
# characters had been mis-encoded, leaving every replacement targeting the
# same garbled character (and the last one mapping a character to itself).
_SPAN_CHAR_MAP = str.maketrans({
    '"': "'",
    "\u201c": "'",  # left double quotation mark
    "\u201d": "'",  # right double quotation mark
    "\u2018": "'",  # left single quotation mark
    "\u2019": "'",  # right single quotation mark
    # NOTE(review): reconstructed as en dash -> em dash; the original pair of
    # literals was unreadable, so confirm against the upstream script.
    "\u2013": "\u2014",
})


def clean_span(span):
    """Normalise quotation marks in a text span and trim its tail.

    Steps, in order:
      1. strip trailing whitespace;
      2. replace straight and typographic quotes with an apostrophe
         (see ``_SPAN_CHAR_MAP``) in a single pass;
      3. drop one trailing ``/`` or ``.`` if present.

    Args:
        span: the caption or phrase text.

    Returns:
        The cleaned string.
    """
    span = span.rstrip().translate(_SPAN_CHAR_MAP)
    if span.endswith(("/", ".")):
        span = span[:-1]
    return span
def check_caption(cap):
    """Decide whether an annotation's caption is clean enough to keep.

    The text that is inspected is ``cap["caption"]`` with trailing
    whitespace removed and then its final character dropped, so the
    caption's last character (e.g. a closing period) is never examined.

    The caption is rejected (``False``) when the inspected text:
      * contains any non-ASCII character;
      * contains one of the reject tokens listed below;
      * contains a ``.`` anywhere except in its own last position;
      * is reported by ``emoji.emoji_count`` as containing an emoji
        (the text is printed in that case).

    Args:
        cap: mapping with a ``"caption"`` string.

    Returns:
        ``True`` if the caption passes every check, otherwise ``False``.
    """
    check_anno = cap["caption"].rstrip()[:-1]
    if not check_anno.isascii():
        return False

    # Only ASCII tokens are listed: anything non-ASCII (emoji, non-breaking
    # spaces, ...) has already been rejected by the guard above, so such
    # entries could never match here.
    reject_tokens = ("-", ",", "*", "/", "$", "[CLS]", "[SEP]", "?")
    if any(token in check_anno for token in reject_tokens):
        return False

    # A period is tolerated only as the last inspected character.
    if "." in check_anno[:-1]:
        return False

    # Backstop kept from the original; it was motivated by a caption that
    # had an emoji in the middle of the sentence.
    if emoji.emoji_count(check_anno):
        print(check_anno)
        return False
    return True
def get_regions(nc, anno):
    """Turn one span record into a ``{"bbox", "phrase"}`` region dict.

    Args:
        nc: indexable record. ``nc[0]`` and ``nc[1]`` are used as the start
            and end character offsets into ``anno["caption"]``; ``nc[2]``
            to ``nc[5]`` are box coordinates that get multiplied by the
            image width/height (presumably normalised x0, y0, x1, y1 --
            confirm against the dataset format).
        anno: annotation mapping with ``"height"``, ``"width"`` and
            ``"caption"``.

    Returns:
        Dict with ``"bbox"`` (four values rounded to 2 decimals) and
        ``"phrase"`` (the caption slice passed through ``clean_span``).
    """
    height = anno["height"]
    width = anno["width"]

    start, stop = int(nc[0]), int(nc[1])
    phrase = clean_span(anno["caption"][start:stop])

    # Coordinates alternate x, y, x, y, so the scale alternates too.
    scales = (width, height, width, height)
    bbox = [round(nc[pos + 2] * scale, 2) for pos, scale in enumerate(scales)]

    return {"bbox": bbox, "phrase": phrase}
def prepare_list(file_name: str, random_samples):
    """Read a list file and draw a random subset of its entries.

    Each non-blank line of ``file_name`` (whitespace-stripped) is one entry.
    Blank lines are skipped so they do not turn into empty paths later on.

    Args:
        file_name: path of the text file holding one entry per line.
        random_samples: number of entries to draw without replacement.
            It is capped at the number of available entries, so asking for
            more than the file contains returns every entry (in random
            order) instead of raising.

    Returns:
        ``(metas, num_of_files)``: the sampled entries and their count.

    Raises:
        ValueError: if ``random_samples`` is negative (from ``random.sample``).
    """
    with open(file_name, "r") as f:
        stripped = (line.strip() for line in f)
        metas = [entry for entry in stripped if entry]
    print(len(metas))

    # random.sample raises when the sample is larger than the population.
    sample_size = min(random_samples, len(metas))
    metas = random.sample(metas, sample_size)

    num_of_files = len(metas)
    print("after sample:", num_of_files)
    return metas, num_of_files
def process_item(file, args):
    """Convert one annotation JSON file into an ODVG grounding record.

    Args:
        file: path of the annotation JSON, relative to ``args.root``.
        args: namespace providing ``root`` and ``min_phrase``, and
            optionally ``chunk_or_ref`` (defaults to 0.5 when absent).

    Returns:
        A dict with ``filename``, ``height``, ``width`` and ``grounding``
        (``caption`` plus ``regions``), or ``None`` when the caption fails
        ``check_caption`` or fewer than ``args.min_phrase`` regions with an
        ASCII phrase remain.
    """
    with open(os.path.join(args.root, file), encoding="utf-8") as f:
        anno = json.load(f)
    if not check_caption(anno):
        return None

    # Honour --chunk_or_ref instead of a hard-coded 0.5: a random draw above
    # the threshold selects the noun chunks, otherwise the referring
    # expressions are used.
    threshold = getattr(args, "chunk_or_ref", 0.5)
    if random.random() > threshold:
        spans = anno["noun_chunks"]
    else:
        spans = anno["ref_exps"]

    candidates = (get_regions(span, anno) for span in spans)
    regions = [region for region in candidates if region["phrase"].isascii()]
    if len(regions) < args.min_phrase:
        return None

    # splitext removes only the final extension, so dots elsewhere in the
    # path (e.g. a leading "./") no longer truncate the file name.
    stem = os.path.splitext(file)[0]
    return {
        "filename": f"{stem}.jpg",
        "height": anno["height"],
        "width": anno["width"],
        "grounding": {
            "caption": clean_span(anno["caption"]),
            "regions": regions,
        },
    }
def _parse_args():
    """Build the command-line parser for the converter and parse ``sys.argv``."""
    # Example values used previously:
    #   --input_file  /share_data/mllm/kosmos-2/GRIT-20M/anno/14m_anno.list
    #   --root        /share_data/mllm/kosmos-2/GRIT-20M/data
    #   --output_file ./girt_14m_odvg.jsonl
    parser = argparse.ArgumentParser(description="GRIT2ODVG List.")
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--root", type=str, default="", help="Source image root")
    parser.add_argument("--output_file", type=str, default="girt_14m_odvg.jsonl")
    parser.add_argument("--random_samples", type=int, default=200000)
    parser.add_argument("--chunk_or_ref", type=float, default=0.5)
    parser.add_argument("--min_phrase", type=int, default=6)
    parser.add_argument("--process_num", type=int, default=10, help="the number of processes")
    return parser.parse_args()


def main():
    """Sample the input list, convert entries in a process pool, write JSONL."""
    args = _parse_args()
    print(args)

    metas, metas_len = prepare_list(args.input_file, args.random_samples)

    # Bind the parsed options so the pool only has to ship the file name.
    worker = partial(process_item, args=args)
    with Pool(processes=args.process_num) as pool:
        results = list(tqdm(pool.imap(func=worker, iterable=metas), total=metas_len))

    # Entries rejected by process_item come back as None and are dropped.
    odvg_anno = [record for record in results if record]

    with jsonlines.open(args.output_file, mode="w") as fwriter:
        fwriter.write_all(odvg_anno)


if __name__ == "__main__":
    main()