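"""Convert GRIT image-caption annotations into the ODVG grounding format.

Reads a list of per-image GRIT JSON files, filters noisy captions, samples
regions from either noun chunks or referring expressions, and writes one
ODVG record per surviving image as JSON Lines.
"""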
import argparse
import json
import os
import random
from functools import partial
from multiprocessing import Pool

import emoji
import jsonlines
from tqdm import tqdm
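

# Normalize a caption span: map typographic quotes to ASCII apostrophes and
# strip a dangling '/' or '.' so stored phrases end cleanly.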
def clean_span(span):
    span = span.rstrip()
    # NOTE: the typographic quote/dash literals below are best-effort
    # reconstructions of mojibake'd characters in the extracted source.
    span = span.replace('"', "'").replace('“', "'").replace('”', "'")
    span = span.replace('‘', "'").replace('’', "'").replace('–', '—')
    if span.endswith('/') or span.endswith('.'):
        span = span[:-1]
    return span
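

# Heuristic caption filter: keep only ASCII captions that contain no
# mid-caption periods, no emoji, and none of a small blocklist of noisy
# characters and tokens seen in GRIT captions.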
def check_caption(anno):
    check_anno = anno["caption"].rstrip()[:-1]
    if not check_anno.isascii():
        return False
    # "The view is better from here 🦅 (Chouf" wtf??
    # NOTE: the emoji entry and "\xa0" below are best-effort reconstructions
    # of mojibake'd characters in the extracted source.
    check_list = {"☝️", "-", ",", "\xa0", "*", "/", "$", "[CLS]", "[SEP]", "?"}
    for ch in check_list:
        if ch in check_anno:
            return False
    if '.' in check_anno[:-1]:
        return False
    if emoji.emoji_count(check_anno):
        print(check_anno)
        return False
    return True
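

# Each GRIT region tuple is assumed to hold a character span into the caption
# followed by a box in normalized coordinates, i.e.
# [char_start, char_end, x1, y1, x2, y2, ...]; the box is scaled here to
# pixel units.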
def get_regions(nc, anno):
    h = anno["height"]
    w = anno["width"]
    phrase = clean_span(anno["caption"][int(nc[0]):int(nc[1])])
    bbox = [round(nc[2] * w, 2), round(nc[3] * h, 2), round(nc[4] * w, 2), round(nc[5] * h, 2)]
    return {
        "bbox": bbox,
        "phrase": phrase
    }
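

# Read the annotation list (one JSON filename per line) and keep a random
# subset of at most `random_samples` entries.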
def prepare_list(file_name: str, random_samples):
    with open(file_name, "r") as f:
        metas = [line.strip() for line in f]
    num_of_files = len(metas)
    print("total files:", num_of_files)
    # guard against requesting more samples than there are entries
    metas = random.sample(metas, min(random_samples, num_of_files))
    num_of_files = len(metas)
    print("after sample:", num_of_files)
    return metas, num_of_files
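

# Convert one GRIT JSON file into an ODVG record shaped like (values
# illustrative):
# {"filename": "0001.jpg", "height": 480, "width": 640,
#  "grounding": {"caption": "...",
#                "regions": [{"bbox": [x1, y1, x2, y2], "phrase": "..."}]}}
# Each image uses either its noun chunks or its referring expressions (a coin
# flip weighted by --chunk_or_ref); images yielding fewer than --min_phrase
# regions are dropped.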
def process_item(file, args):
    with open(os.path.join(args.root, file)) as f:
        anno = json.load(f)
    if not check_caption(anno):
        return None
    noun_chunks = anno['noun_chunks']
    ref_exps = anno['ref_exps']
    regions = []
    # choose between noun chunks and referring expressions
    # (split set by --chunk_or_ref)
    if random.random() > args.chunk_or_ref:
        for nc in noun_chunks:
            region = get_regions(nc, anno)
            if region["phrase"].isascii():
                regions.append(region)
    else:
        for ref in ref_exps:
            region = get_regions(ref, anno)
            if region["phrase"].isascii():
                regions.append(region)
    if len(regions) < args.min_phrase:
        return None
    return {
        "filename": f'{file.split(".")[0]}.jpg',
        "height": anno["height"],
        "width": anno["width"],
        "grounding": {
            "caption": clean_span(anno["caption"]),
            "regions": regions
        }
    }
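

# Example invocation (paths illustrative; assumes this script is saved as
# grit2odvg.py):
#   python grit2odvg.py --input_file anno/14m_anno.list \
#       --root /path/to/GRIT-20M/data --output_file girt_14m_odvg.jsonl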
if __name__ == "__main__":
    # jsons = "/share_data/mllm/kosmos-2/GRIT-20M/anno/14m_anno.list"
    # root = "/share_data/mllm/kosmos-2/GRIT-20M/data"
    # output_name = "./girt_14m_odvg.jsonl"
    parser = argparse.ArgumentParser(description="Convert GRIT annotations to the ODVG format.")
    parser.add_argument("--input_file", type=str, required=True, help="list file with one annotation JSON filename per line")
    parser.add_argument("--root", type=str, default="", help="Source image root")
    parser.add_argument("--output_file", type=str, default="girt_14m_odvg.jsonl")
    parser.add_argument("--random_samples", type=int, default=200000, help="number of annotation files to sample")
    parser.add_argument("--chunk_or_ref", type=float, default=0.5, help="threshold in [0, 1]; random draws above it use noun_chunks, otherwise ref_exps")
    parser.add_argument("--min_phrase", type=int, default=6, help="drop images with fewer regions than this")
    parser.add_argument("--process_num", type=int, default=10, help="the number of processes")
    args = parser.parse_args()
    print(args)
    metas, metas_len = prepare_list(args.input_file, args.random_samples)
    func = partial(process_item, args=args)
    # convert files in parallel; filtered items come back as None
    with Pool(processes=args.process_num) as pool:
        odvg_anno = list(tqdm(pool.imap(func, metas), total=metas_len))
    odvg_anno = [item for item in odvg_anno if item is not None]
    with jsonlines.open(args.output_file, mode="w") as fwriter:
        fwriter.write_all(odvg_anno)