File size: 4,676 Bytes
186701e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# Copyright (c) Tencent Inc. All rights reserved.
import json
import random
from typing import Tuple
import numpy as np
from mmyolo.registry import TRANSFORMS
@TRANSFORMS.register_module()
class RandomLoadText:
    """Sample a per-image class vocabulary and remap labels onto it.

    For every call the transform keeps the classes that occur in the
    ground-truth labels ("positives"), adds randomly drawn classes that do
    not occur ("negatives"), shuffles the combined list and rewrites the
    labels so that they index into that shuffled list. One caption per
    sampled class is written to ``results['texts']``.

    Args:
        text_path: Optional path to a JSON file. When given, its parsed
            content is stored as ``self.class_texts`` and used whenever
            ``results`` carries no ``'texts'`` entry.
        prompt_format: ``str.format`` template applied to each selected
            caption.
        num_neg_samples: ``(low, high)`` bounds handed to
            ``random.randint`` (both inclusive) to draw the requested
            number of negative classes.
        max_num_samples: Upper bound on the number of sampled classes, and
            the target length when ``padding_to_max`` is set.
        padding_to_max: If true, pad ``results['texts']`` with
            ``padding_value`` up to ``max_num_samples`` entries.
        padding_value: String used for padding.
    """

    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 num_neg_samples: Tuple[int, int] = (80, 80),
                 max_num_samples: int = 80,
                 padding_to_max: bool = False,
                 padding_value: str = '') -> None:
        self.prompt_format = prompt_format
        self.num_neg_samples = num_neg_samples
        self.max_num_samples = max_num_samples
        self.padding_to_max = padding_to_max
        self.padding_value = padding_value
        # ``class_texts`` is intentionally left undefined when no path is
        # given: ``__call__`` probes for it with ``hasattr``.
        if text_path is not None:
            # Decode explicitly as UTF-8 so non-ASCII class names load the
            # same way regardless of the platform's locale encoding.
            with open(text_path, 'r', encoding='utf-8') as f:
                self.class_texts = json.load(f)

    def __call__(self, results: dict) -> dict:
        """Sample classes, remap labels in ``results`` and attach texts.

        ``results`` is modified in place and returned. It must contain
        ``'gt_bboxes'`` and either ``'gt_labels'`` or
        ``'gt_bboxes_labels'``; both arrays are filtered with a boolean
        mask, so they have to support that kind of indexing. Class texts
        come from ``results['texts']`` if present, otherwise from the file
        loaded at construction; they are indexed by class id and every
        entry must hold at least one caption.

        Raises:
            AssertionError: If no class texts are available, or a sampled
                class has no caption.
            ValueError: If ``results`` has no label entry.
        """
        assert 'texts' in results or hasattr(self, 'class_texts'), (
            'No texts found in results.')
        class_texts = results.get('texts',
                                  getattr(self, 'class_texts', None))
        num_classes = len(class_texts)

        if 'gt_labels' in results:
            gt_label_tag = 'gt_labels'
        elif 'gt_bboxes_labels' in results:
            gt_label_tag = 'gt_bboxes_labels'
        else:
            raise ValueError('No valid labels found in results.')

        # Positives: distinct classes present in this sample, subsampled
        # when there are more of them than the vocabulary may hold.
        positive_labels = set(results[gt_label_tag])
        if len(positive_labels) > self.max_num_samples:
            positive_labels = set(
                random.sample(list(positive_labels),
                              k=self.max_num_samples))

        # Negatives: fill the remaining room, but never more than the
        # randomly requested amount.
        num_neg_samples = min(
            min(num_classes, self.max_num_samples) - len(positive_labels),
            random.randint(*self.num_neg_samples))
        candidate_neg_labels = [
            idx for idx in range(num_classes) if idx not in positive_labels
        ]
        negative_labels = random.sample(candidate_neg_labels,
                                        k=num_neg_samples)

        sampled_labels = list(positive_labels) + negative_labels
        random.shuffle(sampled_labels)

        # Original class id -> position inside the shuffled vocabulary.
        label2ids = {label: i for i, label in enumerate(sampled_labels)}

        # Relabel in place, then drop every box whose class was not kept
        # (only possible when the positives were subsampled above).
        gt_valid_mask = np.zeros(len(results['gt_bboxes']), dtype=bool)
        for idx, label in enumerate(results[gt_label_tag]):
            if label in label2ids:
                gt_valid_mask[idx] = True
                results[gt_label_tag][idx] = label2ids[label]
        results['gt_bboxes'] = results['gt_bboxes'][gt_valid_mask]
        results[gt_label_tag] = results[gt_label_tag][gt_valid_mask]

        # Apply the same relabel/filter to the raw instance dicts.
        if 'instances' in results:
            retagged_instances = []
            for inst in results['instances']:
                label = inst['bbox_label']
                if label in label2ids:
                    inst['bbox_label'] = label2ids[label]
                    retagged_instances.append(inst)
            results['instances'] = retagged_instances

        texts = []
        for label in sampled_labels:
            cls_caps = class_texts[label]
            assert len(cls_caps) > 0, (
                'Every class needs at least one caption.')
            # A bare string is a single caption, not a sequence of
            # one-character captions.
            if isinstance(cls_caps, str):
                cls_caps = [cls_caps]
            cap_id = random.randrange(len(cls_caps))
            texts.append(self.prompt_format.format(cls_caps[cap_id]))

        if self.padding_to_max:
            num_padding = self.max_num_samples - len(sampled_labels)
            if num_padding > 0:
                texts.extend([self.padding_value] * num_padding)

        results['texts'] = texts
        return results
@TRANSFORMS.register_module()
class LoadText:
    """Attach one formatted caption per class to ``results['texts']``.

    Unlike a sampling transform, every class is kept, in its original
    order, and the first caption of each class is always the one used.

    Args:
        text_path: Optional path to a JSON file. When given, its parsed
            content is stored as ``self.class_texts`` and used whenever
            ``results`` carries no ``'texts'`` entry.
        prompt_format: ``str.format`` template applied to each caption.
        multi_prompt_flag: Stored on the instance as
            ``self.multi_prompt_flag``; nothing in this class reads it.
    """

    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 multi_prompt_flag: str = '/') -> None:
        self.prompt_format = prompt_format
        self.multi_prompt_flag = multi_prompt_flag
        # ``class_texts`` is intentionally left undefined when no path is
        # given: ``__call__`` probes for it with ``hasattr``.
        if text_path is not None:
            # Decode explicitly as UTF-8 so non-ASCII class names load the
            # same way regardless of the platform's locale encoding.
            with open(text_path, 'r', encoding='utf-8') as f:
                self.class_texts = json.load(f)

    def __call__(self, results: dict) -> dict:
        """Replace ``results['texts']`` with one prompt string per class.

        Class texts come from ``results['texts']`` if present, otherwise
        from the file loaded at construction. ``results`` is modified in
        place and returned.

        Raises:
            AssertionError: If no class texts are available, or a class
                has no caption.
        """
        assert 'texts' in results or hasattr(self, 'class_texts'), (
            'No texts found in results.')
        class_texts = results.get('texts',
                                  getattr(self, 'class_texts', None))

        texts = []
        for cls_caps in class_texts:
            assert len(cls_caps) > 0, (
                'Every class needs at least one caption.')
            # A bare string is already the caption; indexing it would
            # keep only its first character.
            if isinstance(cls_caps, str):
                first_cap = cls_caps
            else:
                first_cap = cls_caps[0]
            texts.append(self.prompt_format.format(first_cap))

        results['texts'] = texts
        return results
|