# countgd/tools/grit2odvg.py
# Converts GRIT annotation JSON files into the ODVG jsonl grounding format.
# (File provenance: uploaded to the Hugging Face Hub via huggingface_hub.)
import jsonlines
from tqdm import tqdm
import random
import json
import os
from multiprocessing import Pool
from functools import partial
import emoji
import argparse
def clean_span(span):
    """Normalize a caption span for ODVG output.

    Strips trailing whitespace, maps curly/garbled quote characters to an
    ASCII apostrophe, and drops one trailing '/' or '.'.
    """
    span = span.rstrip()
    # Fix: the original called .replace('"', "'") twice — '\"' and '"' are
    # the same character in Python source, so one replace covers both.
    span = span.replace('"', "'").replace('β€œ', "'").replace('”', "'")
    span = span.replace('β€˜', "'").replace('’', "'").replace('–', "β€”")
    # endswith accepts a tuple of suffixes — one call instead of an `or` chain.
    if span.endswith(('/', '.')):
        span = span[:-1]
    return span
def check_caption(cap):
    """Return True when the caption is usable for grounding.

    A caption is rejected if (after dropping its final character) it is not
    pure ASCII, contains any banned token, has a '.' anywhere before its last
    character, or contains emoji.
    """
    text = cap["caption"].rstrip()[:-1]  # drop the trailing punctuation char
    if not text.isascii():
        return False
    # "The view is better from here 🦅 (Chouf" wtf??
    banned = {"↙️", "-", ",", "Β ", "*", "/", "$", "[CLS]", "[SEP]", "?"}
    if any(token in text for token in banned):
        return False
    # A period anywhere but the very end means a multi-sentence caption.
    if '.' in text[:-1]:
        return False
    if emoji.emoji_count(text):
        print(text)
        return False
    return True
def get_regions(nc, anno):
    """Build one ODVG region dict from a noun-chunk / ref-exp record.

    nc[0:2] are character offsets into the caption; nc[2:6] are normalized
    box coordinates, scaled here to absolute pixels and rounded to 2 places.
    """
    height, width = anno["height"], anno["width"]
    start, end = int(nc[0]), int(nc[1])
    x0, y0, x1, y1 = nc[2], nc[3], nc[4], nc[5]
    return {
        "bbox": [
            round(x0 * width, 2),
            round(y0 * height, 2),
            round(x1 * width, 2),
            round(y1 * height, 2),
        ],
        "phrase": clean_span(anno["caption"][start:end]),
    }
def prepare_list(file_name: str, random_samples):
    """Read newline-separated annotation paths and randomly down-sample them.

    Args:
        file_name: text file with one annotation path per line.
        random_samples: maximum number of entries to keep.

    Returns:
        (sampled_paths, count).

    Fix: random.sample raises ValueError when the requested sample size
    exceeds the population, so clamp it to the number of available lines.
    """
    with open(file_name, "r") as f:
        metas = [line.strip() for line in f]
    print(len(metas))
    metas = random.sample(metas, min(random_samples, len(metas)))
    num_of_files = len(metas)
    print("after sample:", num_of_files)
    return metas, num_of_files
def process_item(file, args):
    """Convert one GRIT annotation file into an ODVG record.

    Args:
        file: annotation filename relative to args.root.
        args: parsed CLI namespace (root, min_phrase, chunk_or_ref).

    Returns:
        An ODVG dict, or None when the caption fails validation or fewer
        than args.min_phrase ASCII phrases survive.
    """
    with open(os.path.join(args.root, file)) as f:
        anno = json.load(f)
    if not check_caption(anno):
        return None
    # Fix: the --chunk_or_ref CLI flag was parsed but never used — the
    # threshold was hard-coded to 0.5. Default preserves old behavior.
    threshold = getattr(args, "chunk_or_ref", 0.5)
    # Pick one region source at random; > threshold selects noun_chunks,
    # matching the original `random.random() > 0.5` direction.
    source = anno["noun_chunks"] if random.random() > threshold else anno["ref_exps"]
    regions = []
    for record in source:
        region = get_regions(record, anno)
        if region["phrase"].isascii():
            regions.append(region)
    if len(regions) < args.min_phrase:
        return None
    # Fix: file.split(".")[0] truncated names containing dots;
    # os.path.splitext removes only the final extension.
    return {
        "filename": f"{os.path.splitext(file)[0]}.jpg",
        "height": anno["height"],
        "width": anno["width"],
        "grounding": {
            "caption": clean_span(anno["caption"]),
            "regions": regions,
        },
    }
if __name__ == "__main__":
    # CLI entry point: sample annotation files, convert them in parallel,
    # and write the surviving ODVG records as jsonl.
    parser = argparse.ArgumentParser(description="GRIT2ODVG List.")
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--root", type=str, default="", help="Source image root")
    parser.add_argument("--output_file", type=str, default="girt_14m_odvg.jsonl")
    parser.add_argument("--random_samples", type=int, default=200000)
    parser.add_argument("--chunk_or_ref", type=float, default=0.5)
    parser.add_argument("--min_phrase", type=int, default=6)
    parser.add_argument("--process_num", type=int, default=10, help="the number of processes")
    args = parser.parse_args()
    print(args)
    metas, metas_len = prepare_list(args.input_file, args.random_samples)
    func = partial(process_item, args=args)
    with Pool(processes=args.process_num) as pool:
        # Drop rejected annotations (None results) as they stream in,
        # instead of buffering them and filtering afterwards.
        odvg_anno = [
            result
            for result in tqdm(pool.imap(func=func, iterable=metas), total=metas_len)
            if result is not None
        ]
    with jsonlines.open(args.output_file, mode="w") as fwriter:
        fwriter.write_all(odvg_anno)