from typing import List, Tuple, Union

from torch import Tensor
from mmdet.structures import OptSampleList, SampleList

from mmyolo.models.detectors import YOLODetector
from mmyolo.registry import MODELS


@MODELS.register_module()
class YOLOWorldDetector(YOLODetector):
    """Implementation of the YOLO-World series detector.

    ``mm_neck`` enables the multi-modal neck that also consumes the text
    features, while ``num_train_classes`` and ``num_test_classes`` record
    the numbers of training and test classes for the bbox head.
    """

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 **kwargs) -> None:
        self.mm_neck = mm_neck
        self.num_train_classes = num_train_classes
        self.num_test_classes = num_test_classes
        super().__init__(*args, **kwargs)

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        self.bbox_head.num_classes = self.num_train_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        losses = self.bbox_head.loss(img_feats, txt_feats, batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing."""
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        # At test time the number of classes equals the number of text
        # embeddings, i.e. the size of the (possibly custom) vocabulary.
        self.bbox_head.num_classes = txt_feats[0].shape[0]
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def reparameterize(self, texts: List[List[str]]) -> None:
        """Cache the class texts so that ``extract_feat`` can fall back to
        them when no data samples (and hence no per-sample texts) are
        given."""
        self.texts = texts
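
    # A minimal offline-vocabulary sketch (``detector`` and the prompt list
    # are hypothetical, not part of this module):
    #
    #     detector.reparameterize([['person'], ['bus'], ['traffic light']])
    #     outs = detector._forward(batch_inputs, batch_data_samples=None)
    #
    # With the texts cached, ``extract_feat`` no longer needs per-sample
    # texts attached to ``batch_data_samples``.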

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        results = self.bbox_head.forward(img_feats, txt_feats)
        return results

    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract image and text features."""
        # The texts come either from the cached vocabulary set by
        # ``reparameterize`` or from the data samples themselves.
        if batch_data_samples is None:
            texts = self.texts
        elif isinstance(batch_data_samples, dict):
            texts = batch_data_samples['texts']
        elif isinstance(batch_data_samples, list):
            texts = [data_sample.texts for data_sample in batch_data_samples]
        else:
            raise TypeError('batch_data_samples should be dict or list.')

        img_feats, txt_feats = self.backbone(batch_inputs, texts)
        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats
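

# A minimal construction sketch. ``MODELS.build`` resolves the registered
# name; the nested ``backbone``/``neck``/``bbox_head`` configs are
# placeholders here, not the project's actual configuration:
#
#     from mmyolo.registry import MODELS
#
#     detector = MODELS.build(
#         dict(type='YOLOWorldDetector',
#              mm_neck=True,
#              num_train_classes=80,
#              num_test_classes=80,
#              backbone=...,
#              neck=...,
#              bbox_head=...))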
|