|
|
|
import json |
|
import random |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
from mmyolo.registry import TRANSFORMS |
|
|
|
|
|
@TRANSFORMS.register_module() |
|
class RandomLoadText: |
|
|
|
def __init__(self, |
|
text_path: str = None, |
|
prompt_format: str = '{}', |
|
num_neg_samples: Tuple[int, int] = (80, 80), |
|
max_num_samples: int = 80, |
|
padding_to_max: bool = False, |
|
padding_value: str = '') -> None: |
|
self.prompt_format = prompt_format |
|
self.num_neg_samples = num_neg_samples |
|
self.max_num_samples = max_num_samples |
|
self.padding_to_max = padding_to_max |
|
self.padding_value = padding_value |
|
if text_path is not None: |
|
with open(text_path, 'r') as f: |
|
self.class_texts = json.load(f) |
|
|
|
def __call__(self, results: dict) -> dict: |
|
assert 'texts' in results or hasattr(self, 'class_texts'), ( |
|
'No texts found in results.') |
|
class_texts = results.get( |
|
'texts', |
|
getattr(self, 'class_texts', None)) |
|
|
|
num_classes = len(class_texts) |
|
if 'gt_labels' in results: |
|
gt_label_tag = 'gt_labels' |
|
elif 'gt_bboxes_labels' in results: |
|
gt_label_tag = 'gt_bboxes_labels' |
|
else: |
|
raise ValueError('No valid labels found in results.') |
|
positive_labels = set(results[gt_label_tag]) |
|
|
|
if len(positive_labels) > self.max_num_samples: |
|
positive_labels = set(random.sample(list(positive_labels), |
|
k=self.max_num_samples)) |
|
|
|
num_neg_samples = min( |
|
min(num_classes, self.max_num_samples) - len(positive_labels), |
|
random.randint(*self.num_neg_samples)) |
|
candidate_neg_labels = [] |
|
for idx in range(num_classes): |
|
if idx not in positive_labels: |
|
candidate_neg_labels.append(idx) |
|
negative_labels = random.sample( |
|
candidate_neg_labels, k=num_neg_samples) |
|
|
|
sampled_labels = list(positive_labels) + list(negative_labels) |
|
random.shuffle(sampled_labels) |
|
|
|
label2ids = {label: i for i, label in enumerate(sampled_labels)} |
|
|
|
gt_valid_mask = np.zeros(len(results['gt_bboxes']), dtype=bool) |
|
for idx, label in enumerate(results[gt_label_tag]): |
|
if label in label2ids: |
|
gt_valid_mask[idx] = True |
|
results[gt_label_tag][idx] = label2ids[label] |
|
results['gt_bboxes'] = results['gt_bboxes'][gt_valid_mask] |
|
results[gt_label_tag] = results[gt_label_tag][gt_valid_mask] |
|
|
|
if 'instances' in results: |
|
retaged_instances = [] |
|
for idx, inst in enumerate(results['instances']): |
|
label = inst['bbox_label'] |
|
if label in label2ids: |
|
inst['bbox_label'] = label2ids[label] |
|
retaged_instances.append(inst) |
|
results['instances'] = retaged_instances |
|
|
|
texts = [] |
|
for label in sampled_labels: |
|
cls_caps = class_texts[label] |
|
assert len(cls_caps) > 0 |
|
cap_id = random.randrange(len(cls_caps)) |
|
sel_cls_cap = self.prompt_format.format(cls_caps[cap_id]) |
|
texts.append(sel_cls_cap) |
|
|
|
if self.padding_to_max: |
|
num_valid_labels = len(positive_labels) + len(negative_labels) |
|
num_padding = self.max_num_samples - num_valid_labels |
|
if num_padding > 0: |
|
texts += [self.padding_value] * num_padding |
|
|
|
results['texts'] = texts |
|
|
|
return results |
|
|
|
|
|
@TRANSFORMS.register_module() |
|
class LoadText: |
|
|
|
def __init__(self, |
|
text_path: str = None, |
|
prompt_format: str = '{}', |
|
multi_prompt_flag: str = '/') -> None: |
|
self.prompt_format = prompt_format |
|
self.multi_prompt_flag = multi_prompt_flag |
|
if text_path is not None: |
|
with open(text_path, 'r') as f: |
|
self.class_texts = json.load(f) |
|
|
|
def __call__(self, results: dict) -> dict: |
|
assert 'texts' in results or hasattr(self, 'class_texts'), ( |
|
'No texts found in results.') |
|
class_texts = results.get( |
|
'texts', |
|
getattr(self, 'class_texts', None)) |
|
|
|
texts = [] |
|
for idx, cls_caps in enumerate(class_texts): |
|
assert len(cls_caps) > 0 |
|
sel_cls_cap = cls_caps[0] |
|
sel_cls_cap = self.prompt_format.format(sel_cls_cap) |
|
texts.append(sel_cls_cap) |
|
|
|
results['texts'] = texts |
|
|
|
return results |
|
|