File size: 3,861 Bytes
a277bb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import jsonlines
from tqdm import tqdm
import random
import json
import os
from multiprocessing import Pool
from functools import partial
import emoji

import argparse

# One-pass character map used by clean_span. Written with \u escapes so the
# table survives any re-encoding of this source file.
_SPAN_TRANSLATION = str.maketrans({
    '"': "'",
    "\u201c": "'",  # left double quotation mark
    "\u201d": "'",  # right double quotation mark
    "\u2018": "'",  # left single quotation mark
    "\u2019": "'",  # right single quotation mark
    "\u2013": "\u2014",  # en dash -> em dash
})


def clean_span(span):
    """Normalize quotes/dashes in a text span and trim its tail.

    Steps, in order:
      1. Strip trailing whitespace.
      2. Replace straight and typographic double/single quotes with an ASCII
         apostrophe, and an en dash with an em dash.
      3. Drop one trailing '/' or '.' if present.

    Args:
        span: The text to clean.

    Returns:
        The cleaned string. An empty input yields an empty string.
    """
    span = span.rstrip().translate(_SPAN_TRANSLATION)
    # Only a single trailing terminator is removed; "a.." becomes "a.".
    if span.endswith(("/", ".")):
        span = span[:-1]
    return span

def check_caption(cap):
    """Decide whether an annotation's caption is clean enough to keep.

    The caption is right-stripped and its final character is dropped before
    checking (presumably the sentence-ending period -- confirm against the
    data). The remaining text is rejected when it:
      * contains any non-ASCII character,
      * contains any token from the deny-list below,
      * contains a '.' anywhere except its last character,
      * contains an emoji according to ``emoji.emoji_count``.

    Args:
        cap: Mapping with a "caption" string entry.

    Returns:
        True if the caption passes every check, False otherwise.
    """
    check_anno = cap["caption"].rstrip()[:-1]
    if not check_anno.isascii():
        return False
    # Deny-list of punctuation and tokenizer markers. Example of a caption the
    # filters are meant to drop: "The view is better from here <eagle emoji> (Chouf".
    # The non-ASCII entries are spelled as escapes so they cannot be corrupted
    # by a re-encoding of this file; they are already covered by the ASCII
    # check above and are kept as a safeguard.
    check_list = {
        "\u2199\ufe0f",  # south-west arrow emoji
        "-",
        ",",
        "\u00a0",  # no-break space
        "*",
        "/",
        "$",
        "[CLS]",
        "[SEP]",
        "?",
    }
    if any(token in check_anno for token in check_list):
        return False
    # A period before the final position suggests more than one sentence or an
    # abbreviation; such captions are rejected.
    if "." in check_anno[:-1]:
        return False
    if emoji.emoji_count(check_anno):
        # Kept from the original: echo the offending caption for inspection.
        print(check_anno)
        return False
    return True

def get_regions(nc, anno):
    """Build a region entry (phrase + box) from one span record.

    Args:
        nc: Indexable record. Items 0 and 1 are cast to int and used as the
            start/end offsets into the caption; items 2..5 are multiplied by
            width, height, width, height respectively (presumably normalized
            x0, y0, x1, y1 -- confirm against the dataset spec).
        anno: Annotation mapping providing "height", "width" and "caption".

    Returns:
        A dict with "bbox" (four values rounded to 2 decimals) and "phrase"
        (the caption slice passed through ``clean_span``).
    """
    img_h = anno["height"]
    img_w = anno["width"]
    start, end = int(nc[0]), int(nc[1])
    phrase = clean_span(anno["caption"][start:end])
    # Coordinates alternate between the horizontal and vertical axis.
    scales = (img_w, img_h, img_w, img_h)
    bbox = [round(nc[idx] * scale, 2) for idx, scale in enumerate(scales, start=2)]
    return {"bbox": bbox, "phrase": phrase}


def prepare_list(file_name: str, random_samples):
    """Read a list of annotation file names and draw a random subset.

    Each non-blank line of ``file_name`` (whitespace-stripped) is one entry.
    The number of entries before and after sampling is printed.

    Args:
        file_name: Path of the text file holding one entry per line.
        random_samples: Number of entries to draw without replacement. When it
            exceeds the number of available entries, all entries are returned
            (in random order) instead of raising.

    Returns:
        A tuple ``(metas, num_of_files)`` with the sampled entries and their
        count.
    """
    with open(file_name, "r") as f:
        # Blank lines would turn into empty file names downstream; skip them.
        metas = [entry for entry in (line.strip() for line in f) if entry]
    num_of_files = len(metas)
    print(num_of_files)
    # random.sample raises ValueError if asked for more items than exist, so
    # clamp the request to the population size.
    metas = random.sample(metas, min(random_samples, num_of_files))
    num_of_files = len(metas)
    print("after sample:", num_of_files)
    return metas, num_of_files


def process_item(file, args):
    """Convert one annotation JSON file into an ODVG grounding record.

    Args:
        file: Annotation file name, joined onto ``args.root``.
        args: Options object. Reads ``root`` and ``min_phrase``; also reads
            ``chunk_or_ref`` when present (defaults to 0.5 otherwise).

    Returns:
        A dict with "filename", "height", "width" and "grounding" (cleaned
        caption plus region list), or None when the caption fails
        ``check_caption`` or fewer than ``args.min_phrase`` regions remain.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(os.path.join(args.root, file), encoding="utf-8") as f:
        anno = json.load(f)
    if not check_caption(anno):
        return None
    # Both keys are looked up unconditionally so a record missing either one
    # fails the same way regardless of the random draw.
    noun_chunks = anno["noun_chunks"]
    ref_exps = anno["ref_exps"]
    # A draw above the threshold selects noun chunks, otherwise referring
    # expressions; i.e. the threshold is the probability of using ref_exps.
    threshold = getattr(args, "chunk_or_ref", 0.5)
    spans = noun_chunks if random.random() > threshold else ref_exps
    regions = []
    for span in spans:
        region = get_regions(span, anno)
        # Keep only phrases that are pure ASCII after cleaning.
        if region["phrase"].isascii():
            regions.append(region)
    if len(regions) < args.min_phrase:
        return None
    # NOTE(review): everything after the first '.' in `file` is discarded, so
    # names containing extra dots (e.g. "./x.json") are truncated -- confirm
    # that the list only holds plain "<id>.json" names.
    return {
        "filename": f'{file.split(".")[0]}.jpg',
        "height": anno["height"],
        "width": anno["width"],
        "grounding": {
            "caption": clean_span(anno["caption"]),
            "regions": regions,
        },
    }

if __name__ == "__main__":
    # Example values used previously:
    #   input list : /share_data/mllm/kosmos-2/GRIT-20M/anno/14m_anno.list
    #   root       : /share_data/mllm/kosmos-2/GRIT-20M/data
    #   output     : ./girt_14m_odvg.jsonl
    cli = argparse.ArgumentParser(description="GRIT2ODVG List.")
    # (flag, keyword options) pairs, registered in this order.
    cli_options = (
        ("--input_file", {"type": str, "required": True}),
        ("--root", {"type": str, "default": "", "help": "Source image root"}),
        ("--output_file", {"type": str, "default": "girt_14m_odvg.jsonl"}),
        ("--random_samples", {"type": int, "default": 200000}),
        ("--chunk_or_ref", {"type": float, "default": 0.5}),
        ("--min_phrase", {"type": int, "default": 6}),
        ("--process_num", {"type": int, "default": 10, "help": "the number of processes"}),
    )
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    args = cli.parse_args()
    print(args)

    metas, metas_len = prepare_list(args.input_file, args.random_samples)
    convert = partial(process_item, args=args)
    # Convert the files in worker processes; results arrive in input order and
    # rejected items (None) are dropped.
    with Pool(processes=args.process_num) as workers:
        progress = tqdm(workers.imap(func=convert, iterable=metas), total=metas_len)
        odvg_anno = [record for record in progress if record]

    with jsonlines.open(args.output_file, mode="w") as sink:
        sink.write_all(odvg_anno)