# Spaces: Running on T4
import os
import random
from typing import List

import torch
def _char_to_token_with_fallback(tokenized, char_idx, step):
    """Map a character offset to a token index, probing up to two neighbours.

    Tries ``char_idx`` first. If that character maps to no token, tries
    ``char_idx + step`` and then ``char_idx + 2 * step``. Exceptions raised by
    the neighbour probes (they may step outside the text) are treated as
    "no token". An exception from the first lookup is not caught.

    Returns:
        The token index, or ``None`` if none of the probed characters maps to a token.
    """
    pos = tokenized.char_to_token(char_idx)
    if pos is not None:
        return pos
    try:
        for offset in (1, 2):
            pos = tokenized.char_to_token(char_idx + offset * step)
            if pos is not None:
                return pos
    except Exception:  # narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate
        return None
    return None


def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
    """Construct a map such that positive_map[i, j] > 0 iff box i is associated with token j.

    Input:
        - tokenized: tokenizer output exposing ``char_to_token(char_idx)``
            (presumably a Hugging Face ``BatchEncoding``), holding
            - input_ids: Tensor[1, ntokens]
            - attention_mask: Tensor[1, ntokens]
        - token_span: list with length num_boxes.
            - each item: list of [start_idx, end_idx] character spans (end exclusive)
        - max_text_len: number of token columns in the returned map.

    Returns:
        Float tensor of shape (len(token_span), max_text_len). Each row is
        normalized to sum to ~1; rows with no resolvable span stay all-zero.
    """
    positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
    # Debug switch: mark only the first token of the first resolvable span per box.
    only_one_pos = os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE"
    for j, tok_list in enumerate(token_span):
        for beg, end in tok_list:
            # A character that maps to no token is retried with its neighbours:
            # forward from the span start, backward from the (exclusive) span end.
            beg_pos = _char_to_token_with_fallback(tokenized, beg, step=1)
            end_pos = _char_to_token_with_fallback(tokenized, end - 1, step=-1)
            if beg_pos is None or end_pos is None:
                continue
            if only_one_pos:
                positive_map[j, beg_pos] = 1
                break
            # Slicing silently clips token indices beyond max_text_len.
            positive_map[j, beg_pos : end_pos + 1].fill_(1)
    # The 1e-6 guards against division by zero for boxes with no matched tokens.
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
def build_captions_and_token_span(cat_list, force_lowercase, *, rng=None):
    """Concatenate category names into a caption and record their character spans.

    Each category contributes its space-separated words followed by " .",
    e.g. ``["dog", "hot dog"]`` -> ``"dog . hot dog ."``. A name containing
    "/" is treated as a set of synonyms: one of its "/"-separated parts, or
    the full name itself, is chosen at random. The chosen string is used both
    in the caption and as the dictionary key.

    Args:
        cat_list: iterable of category-name strings.
        force_lowercase: if true, names are lowercased before use.
        rng: optional ``random.Random`` instance for the synonym choice. It
            defaults to the module-level ``random`` functions (original behaviour).

    Return:
        captions: str
        cat2tokenspan: dict mapping each used name to its list of
            [start_idx, end_idx] character spans (end exclusive), e.g.
            {
                'dog': [[0, 3]],
                ...
            }
            Names without any non-empty word are omitted and add nothing to the caption.
    """
    chooser = random if rng is None else rng
    cat2tokenspan = {}
    parts = []  # caption pieces, joined once at the end (avoids quadratic `+=`)
    length = 0  # length of the caption assembled so far
    for catname in cat_list:
        class_name = catname.lower() if force_lowercase else catname
        if "/" in class_name:
            synonyms: List = class_name.strip().split("/")
            synonyms.append(class_name)
            class_name = chooser.choice(synonyms)

        spans = []
        for subname in (s.strip() for s in class_name.strip().split(" ")):
            if not subname:
                continue
            if length > 0:
                parts.append(" ")
                length += 1
            start_idx = length
            parts.append(subname)
            length += len(subname)
            spans.append([start_idx, length])

        if spans:
            parts.append(" .")
            length += 2
            cat2tokenspan[class_name] = spans
    return "".join(parts), cat2tokenspan
def _resolve_posspan(cat2posspan, catname):
    """Return the spans recorded for ``catname``, following a "/" synonym choice."""
    if catname in cat2posspan:
        return cat2posspan[catname]
    # For "a/b" names, build_captions_and_token_span keys the spans by whichever
    # synonym it randomly picked (unstripped "/"-split part), so look for that key.
    # NOTE(review): if another category is literally named like the chosen
    # synonym, the two already share one dict key upstream.
    if "/" in catname:
        for synonym in catname.strip().split("/"):
            if synonym in cat2posspan:
                return cat2posspan[synonym]
    raise KeyError(catname)


def build_id2posspan_and_caption(category_dict: List[dict]):
    """Build a category-id -> character-span map and the matching caption.

    Args:
        category_dict: sequence of category records, each a dict with at least
            "id" and "name" keys. It is iterated twice, so it must be
            re-iterable (e.g. a list). Despite the name, it is not a plain dict.

    Returns:
        id2posspan: dict mapping each category id to the [start_idx, end_idx]
            character spans (end exclusive) of its lowercased name, or of the
            synonym chosen for a "/"-separated name, within ``caption``.
        caption: the caption string built by ``build_captions_and_token_span``.

    Raises:
        KeyError: if a category name yields no caption words (e.g. it is empty
            or whitespace-only), so no span was recorded for it.
    """
    cat_list = [item["name"].lower() for item in category_dict]
    id2catname = {item["id"]: item["name"].lower() for item in category_dict}
    caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True)
    id2posspan = {
        catid: _resolve_posspan(cat2posspan, catname)
        for catid, catname in id2catname.items()
    }
    return id2posspan, caption