|
|
|
|
|
import contextlib |
|
from copy import deepcopy |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ultralytics.nn.modules import ( |
|
AIFI, |
|
C1, |
|
C2, |
|
C3, |
|
C3TR, |
|
OBB, |
|
SPP, |
|
SPPF, |
|
Bottleneck, |
|
BottleneckCSP, |
|
C2f, |
|
C2fAttn, |
|
ImagePoolingAttn, |
|
C3Ghost, |
|
C3x, |
|
Classify, |
|
Concat, |
|
Conv, |
|
Conv2, |
|
ConvTranspose, |
|
Detect, |
|
DWConv, |
|
DWConvTranspose2d, |
|
Focus, |
|
GhostBottleneck, |
|
GhostConv, |
|
HGBlock, |
|
HGStem, |
|
Pose, |
|
RepC3, |
|
RepConv, |
|
ResNetLayer, |
|
RTDETRDecoder, |
|
Segment, |
|
WorldDetect, |
|
RepNCSPELAN4, |
|
ADown, |
|
SPPELAN, |
|
CBFuse, |
|
CBLinear, |
|
Silence, |
|
C2fCIB, |
|
PSA, |
|
SCDown, |
|
RepVGGDW, |
|
v10Detect |
|
) |
|
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load |
|
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml |
|
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss |
|
from ultralytics.utils.plotting import feature_visualization |
|
from ultralytics.utils.torch_utils import ( |
|
fuse_conv_and_bn, |
|
fuse_deconv_and_bn, |
|
initialize_weights, |
|
intersect_dicts, |
|
make_divisible, |
|
model_info, |
|
scale_img, |
|
time_sync, |
|
) |
|
|
|
try: |
|
import thop |
|
except ImportError: |
|
thop = None |
|
|
|
|
|
class BaseModel(nn.Module):
    """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

    def forward(self, x, *args, **kwargs):
        """
        Forward pass of the model on a single scale. Dispatches to `loss` for dict inputs, else to `predict`.

        Args:
            x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
            *args: Extra positional arguments forwarded unchanged to `loss` / `predict`.
            **kwargs: Extra keyword arguments forwarded unchanged to `loss` / `predict`.

        Returns:
            (torch.Tensor): The output of the network (whatever `loss` returns when `x` is a dict).
        """
        if isinstance(x, dict):  # a dict input is treated as a batch to compute the loss on
            return self.loss(x, *args, **kwargs)
        return self.predict(x, *args, **kwargs)

    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.
            augment (bool): Augment image during prediction, defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        if augment:
            # NOTE: profile / visualize / embed are not forwarded on the augmented path.
            return self._predict_augment(x)
        return self._predict_once(x, profile, visualize, embed)

    def _predict_once(self, x, profile=False, visualize=False, embed=None):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): The last output of the model, or a tuple of per-sample embedding vectors when `embed`
                is given and its highest layer index has been reached.
        """
        y, dt, embeddings = [], [], []  # saved layer outputs, per-layer timings, pooled embeddings
        last_embed = max(embed) if embed else None  # loop-invariant: layer index at which to return early
        for m in self.model:
            if m.f != -1:  # input is not simply the previous layer's output
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)
            y.append(x if m.i in self.save else None)  # keep only outputs that later layers read
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))
                if m.i == last_embed:
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def _predict_augment(self, x):
        """Perform augmentations on input image x and return augmented inference."""
        LOGGER.warning(
            f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
            f"Reverting to single-scale inference instead."
        )
        return self._predict_once(x)

    def _profile_one_layer(self, m, x, dt):
        """
        Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to
        the provided list.

        Args:
            m (nn.Module): The layer to be profiled.
            x (torch.Tensor): The input data to the layer.
            dt (list): A list to store the computation time of the layer.

        Returns:
            None
        """
        c = m == self.model[-1] and isinstance(x, list)  # final layer fed a list: pass a copy on every call
        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0
        t = time_sync()
        for _ in range(10):
            m(x.copy() if c else x)
        # Elapsed time over 10 runs, times 100 (per-run milliseconds, assuming time_sync() returns seconds).
        dt.append((time_sync() - t) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
        LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}")
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

    def fuse(self, verbose=True):
        """
        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
        computation efficiency.

        Args:
            verbose (bool): Forwarded to `info()` after fusing. Defaults to True.

        Returns:
            (nn.Module): The fused model is returned.
        """
        if not self.is_fused():
            for m in self.model.modules():
                if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
                    if isinstance(m, Conv2):
                        m.fuse_convs()
                    m.conv = fuse_conv_and_bn(m.conv, m.bn)
                    delattr(m, "bn")
                    m.forward = m.forward_fuse
                if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
                    m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                    delattr(m, "bn")
                    m.forward = m.forward_fuse
                if isinstance(m, RepConv):
                    m.fuse_convs()
                    m.forward = m.forward_fuse
                if isinstance(m, RepVGGDW):
                    m.fuse()
                    m.forward = m.forward_fuse
            self.info(verbose=verbose)

        return self

    def is_fused(self, thresh=10):
        """
        Check if the model has less than a certain threshold of BatchNorm layers.

        Args:
            thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.

        Returns:
            (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
        """
        # Every attribute of torch.nn whose name contains "Norm" counts as a normalization layer here.
        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)
        return sum(isinstance(v, bn) for v in self.modules()) < thresh

    def info(self, detailed=False, verbose=True, imgsz=640):
        """
        Prints model information.

        Args:
            detailed (bool): if True, prints out detailed information about the model. Defaults to False
            verbose (bool): if True, prints out the model information. Defaults to True
            imgsz (int): the size of the image that the model will be trained on. Defaults to 640
        """
        return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

    def _apply(self, fn):
        """
        Applies a function to all the tensors in the model that are not parameters or registered buffers.

        Args:
            fn (function): the function to apply to the model

        Returns:
            (BaseModel): An updated BaseModel object.
        """
        self = super()._apply(fn)
        m = self.model[-1]
        if isinstance(m, Detect):
            # These are plain attributes on the head, so fn is applied to them explicitly.
            m.stride = fn(m.stride)
            m.anchors = fn(m.anchors)
            m.strides = fn(m.strides)
        return self

    def load(self, weights, verbose=True):
        """
        Load the weights into the model.

        Args:
            weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
            verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
        """
        model = weights["model"] if isinstance(weights, dict) else weights
        csd = model.float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, self.state_dict())  # keep only the entries intersect_dicts selects
        self.load_state_dict(csd, strict=False)
        if verbose:
            LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")

    def loss(self, batch, preds=None):
        """
        Compute loss.

        Args:
            batch (dict): Batch to compute loss on
            preds (torch.Tensor | List[torch.Tensor]): Predictions. Computed from `batch["img"]` when None.
        """
        if not hasattr(self, "criterion"):  # build the criterion lazily on first use
            self.criterion = self.init_criterion()

        preds = self.forward(batch["img"]) if preds is None else preds
        return self.criterion(preds, batch)

    def init_criterion(self):
        """Initialize the loss criterion for the BaseModel."""
        raise NotImplementedError("compute_loss() needs to be implemented by task heads")
|
|
|
|
|
class DetectionModel(BaseModel):
    """YOLOv8 detection model."""

    def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True):
        """
        Build the detection model from a config dict or YAML path.

        Args:
            cfg (str | dict): Model config dict, or a path/name handed to `yaml_model_load`.
            ch (int): Input channels, used only when the config has no "ch" entry.
            nc (int, optional): Class count; overrides the config's "nc" when given and different.
            verbose (bool): Log the layer table and model summary.
        """
        super().__init__()
        self.yaml = yaml_model_load(cfg) if not isinstance(cfg, dict) else cfg

        # Model definition: the config's own "ch" wins over the argument.
        self.yaml["ch"] = self.yaml.get("ch", ch)
        ch = self.yaml["ch"]
        if nc:
            if nc != self.yaml["nc"]:
                LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
                self.yaml["nc"] = nc
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)
        self.names = {idx: str(idx) for idx in range(self.yaml["nc"])}  # default names are the class indices
        self.inplace = self.yaml.get("inplace", True)

        # Strides: measured by pushing a dummy image through the freshly built network.
        head = self.model[-1]
        if not isinstance(head, Detect):
            self.stride = torch.Tensor([32])
        else:
            probe_size = 256
            head.inplace = self.inplace

            def run_probe(img):
                """Return the feature maps whose heights define the strides."""
                if isinstance(head, v10Detect):
                    return self.forward(img)["one2many"]
                out = self.forward(img)
                return out[0] if isinstance(head, (Segment, Pose, OBB)) else out

            feature_maps = run_probe(torch.zeros(1, ch, probe_size, probe_size))
            head.stride = torch.tensor([probe_size / fm.shape[-2] for fm in feature_maps])
            self.stride = head.stride
            head.bias_init()

        # Weights and summary
        initialize_weights(self)
        if verbose:
            self.info()
            LOGGER.info("")

    def _predict_augment(self, x):
        """Run inference at three scales (one flipped along dim 3) and return the merged predictions and None."""
        img_size = x.shape[-2:]
        schedule = ((1, None), (0.83, 3), (0.67, None))  # (scale, flip dim) pairs
        grid = int(self.stride.max())
        outputs = []
        for scale, flip_dim in schedule:
            source = x.flip(flip_dim) if flip_dim else x
            pred = super().predict(scale_img(source, scale, gs=grid))
            if isinstance(pred, dict):
                pred = pred["one2one"]
            if isinstance(pred, (list, tuple)):
                pred = pred[0]
            outputs.append(self._descale_pred(pred, flip_dim, scale, img_size))
        outputs = self._clip_augmented(outputs)
        return torch.cat(outputs, -1), None

    @staticmethod
    def _descale_pred(p, flips, scale, img_size, dim=1):
        """Undo the scale/flip applied for augmented inference; the first four rows of `p` are rescaled in place."""
        p[:, :4] /= scale
        n_rest = p.shape[dim] - 4
        x, y, wh, cls = p.split((1, 1, 2, n_rest), dim)
        if flips == 2:  # flip along dim 2 is undone on y
            y = img_size[0] - y
        elif flips == 3:  # flip along dim 3 is undone on x
            x = img_size[1] - x
        return torch.cat((x, y, wh, cls), dim)

    def _clip_augmented(self, y):
        """Clip YOLO augmented inference tails: trims the end of the first output and the start of the last."""
        nl = self.model[-1].nl
        grid_points = sum(4**k for k in range(nl))
        n_excluded = 1  # number of layers' worth of predictions to drop on each side

        tail = (y[0].shape[-1] // grid_points) * sum(4**k for k in range(n_excluded))
        y[0] = y[0][..., :-tail]

        lead = (y[-1].shape[-1] // grid_points) * sum(4 ** (nl - 1 - k) for k in range(n_excluded))
        y[-1] = y[-1][..., lead:]
        return y

    def init_criterion(self):
        """Return the detection loss bound to this model."""
        return v8DetectionLoss(self)
|
|
|
|
|
class OBBModel(DetectionModel):
    """YOLOv8 Oriented Bounding Box (OBB) model."""

    def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
        """Build the OBB model; all arguments are passed straight to `DetectionModel.__init__`."""
        super().__init__(cfg, ch, nc, verbose)

    def init_criterion(self):
        """Return the OBB loss bound to this model."""
        criterion = v8OBBLoss(self)
        return criterion
|
|
|
|
|
class SegmentationModel(DetectionModel):
    """YOLOv8 segmentation model."""

    def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):
        """Build the segmentation model; all arguments are passed straight to `DetectionModel.__init__`."""
        super().__init__(cfg, ch, nc, verbose)

    def init_criterion(self):
        """Return the segmentation loss bound to this model."""
        criterion = v8SegmentationLoss(self)
        return criterion
|
|
|
|
|
class PoseModel(DetectionModel):
    """YOLOv8 pose model."""

    def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
        """
        Build the pose model, optionally overriding the config's keypoint shape.

        Args:
            cfg (str | dict): Model config dict, or a path/name handed to `yaml_model_load`.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            data_kpt_shape (tuple): Keypoint shape from the dataset; replaces the config's "kpt_shape" when any
                entry is truthy and the two differ.
            verbose (bool): Log model information during construction.
        """
        cfg = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)
        # The config is only consulted when the dataset actually supplies a shape.
        if any(data_kpt_shape) and list(cfg["kpt_shape"]) != list(data_kpt_shape):
            LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
            cfg["kpt_shape"] = data_kpt_shape
        super().__init__(cfg, ch, nc, verbose)

    def init_criterion(self):
        """Return the pose loss bound to this model."""
        criterion = v8PoseLoss(self)
        return criterion
|
|
|
|
|
class ClassificationModel(BaseModel):
    """YOLOv8 classification model."""

    def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
        """Create the model and build it from `cfg` (YAML path or dict), `ch` input channels and `nc` classes."""
        super().__init__()
        self._from_yaml(cfg, ch, nc, verbose)

    def _from_yaml(self, cfg, ch, nc, verbose):
        """
        Load the configuration and build the network.

        Args:
            cfg (str | dict): Model config dict, or a path/name handed to `yaml_model_load`.
            ch (int): Input channels, used only when the config has no "ch" entry.
            nc (int | None): Class count; overrides the config's "nc" when given and different.
            verbose (bool): Passed to `parse_model`.

        Raises:
            ValueError: If neither `nc` nor the config provides a class count.
        """
        self.yaml = yaml_model_load(cfg) if not isinstance(cfg, dict) else cfg

        # Model definition: the config's own "ch" wins over the argument.
        self.yaml["ch"] = self.yaml.get("ch", ch)
        ch = self.yaml["ch"]
        if nc:
            if nc != self.yaml["nc"]:
                LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
                self.yaml["nc"] = nc
        elif not self.yaml.get("nc", None):
            raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)
        self.stride = torch.Tensor([1])
        self.names = {idx: str(idx) for idx in range(self.yaml["nc"])}  # default names are the class indices
        self.info()

    @staticmethod
    def reshape_outputs(model, nc):
        """
        Resize the output layer of `model` to `nc` classes if it currently differs.

        Only the last named child of `model.model` (or of `model` itself when it has no `model` attribute) is
        inspected. Supported shapes: a `Classify` head, a bare `nn.Linear`, or an `nn.Sequential` containing an
        `nn.Linear` (preferred) or an `nn.Conv2d`; the first layer of that type is replaced.
        """
        container = getattr(model, "model", model)
        name, m = list(container.named_children())[-1]

        if isinstance(m, Classify):
            if m.linear.out_features != nc:
                m.linear = nn.Linear(m.linear.in_features, nc)
            return

        if isinstance(m, nn.Linear):
            if m.out_features != nc:
                setattr(model, name, nn.Linear(m.in_features, nc))
            return

        if not isinstance(m, nn.Sequential):
            return

        layer_types = [type(layer) for layer in m]
        if nn.Linear in layer_types:
            idx = layer_types.index(nn.Linear)
            old = m[idx]
            if old.out_features != nc:
                m[idx] = nn.Linear(old.in_features, nc)
        elif nn.Conv2d in layer_types:
            idx = layer_types.index(nn.Conv2d)
            old = m[idx]
            if old.out_channels != nc:
                m[idx] = nn.Conv2d(old.in_channels, nc, old.kernel_size, old.stride, bias=old.bias is not None)

    def init_criterion(self):
        """Return the classification loss."""
        criterion = v8ClassificationLoss()
        return criterion
|
|
|
|
|
class RTDETRDetectionModel(DetectionModel):
    """
    RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.

    Builds the RTDETR architecture on top of DetectionModel and overrides loss computation and the forward pass so
    that the final (decoder) layer also receives the training targets.

    Attributes:
        cfg (str): The configuration file path or preset string. Default is 'rtdetr-l.yaml'.
        ch (int): Number of input channels. Default is 3 (RGB).
        nc (int, optional): Number of classes for object detection. Default is None.
        verbose (bool): Specifies if summary statistics are shown during initialization. Default is True.

    Methods:
        init_criterion: Initializes the criterion used for loss calculation.
        loss: Computes and returns the loss during training.
        predict: Performs a forward pass through the network and returns the output.
    """

    def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
        """
        Build the RTDETR model; all arguments are passed straight to `DetectionModel.__init__`.

        Args:
            cfg (str): Configuration file name or path.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes. Defaults to None.
            verbose (bool, optional): Print additional information during initialization. Defaults to True.
        """
        super().__init__(cfg, ch, nc, verbose)

    def init_criterion(self):
        """Return the RTDETR detection loss for `self.nc` classes with `use_vfl=True`."""
        # Imported here rather than at module level, as in the original implementation.
        from ultralytics.models.utils.loss import RTDETRDetectionLoss

        criterion = RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
        return criterion

    def loss(self, batch, preds=None):
        """
        Compute the loss for the given batch of data.

        Args:
            batch (dict): Dictionary with "img", "batch_idx", "cls" and "bboxes" entries.
            preds (torch.Tensor, optional): Precomputed model predictions. Defaults to None.

        Returns:
            (tuple): The summed loss and a tensor holding the detached "loss_giou", "loss_class" and "loss_bbox"
                values, in that order.
        """
        if not hasattr(self, "criterion"):
            self.criterion = self.init_criterion()

        img = batch["img"]
        device = img.device
        batch_idx = batch["batch_idx"]

        # Number of ground-truth entries belonging to each image of the batch.
        gt_groups = [(batch_idx == image_index).sum().item() for image_index in range(len(img))]
        targets = {
            "cls": batch["cls"].to(device, dtype=torch.long).view(-1),
            "bboxes": batch["bboxes"].to(device=device),
            "batch_idx": batch_idx.to(device, dtype=torch.long).view(-1),
            "gt_groups": gt_groups,
        }

        if preds is None:
            preds = self.predict(img, batch=targets)
        head_out = preds if self.training else preds[1]
        dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = head_out

        dn_bboxes = dn_scores = None
        if dn_meta is not None:
            # Separate the leading denoising part from the regular decoder outputs along dim 2.
            dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
            dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)

        # Prepend the encoder outputs as an extra leading entry.
        dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])
        dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])

        loss = self.criterion(
            (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
        )

        main_terms = [loss[key].detach() for key in ["loss_giou", "loss_class", "loss_bbox"]]
        return sum(loss.values()), torch.as_tensor(main_terms, device=device)

    def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
        """
        Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
            batch (dict, optional): Targets handed to the final layer. Defaults to None.
            augment (bool, optional): Accepted for interface compatibility; not used by this method.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        y, dt, embeddings = [], [], []
        *body, head = self.model  # every layer except the last runs through the generic loop
        for m in body:
            if m.f != -1:
                if isinstance(m.f, int):
                    x = y[m.f]
                else:
                    x = [x if j == -1 else y[j] for j in m.f]
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)
            y.append(x if m.i in self.save else None)
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                pooled = nn.functional.adaptive_avg_pool2d(x, (1, 1))
                embeddings.append(pooled.squeeze(-1).squeeze(-1))
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)

        # The final layer takes its saved inputs plus the targets.
        x = head([y[j] for j in head.f], batch)
        return x
|
|
|
|
|
class WorldModel(DetectionModel):
    """YOLOv8 World Model."""

    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
        """Build the world model, starting from random text features of shape (1, nc or 80, 512)."""
        # Assigned before the parent constructor because that constructor may run `forward`, which reads
        # `self.txt_feats` through `predict`.
        n_classes = nc or 80
        self.txt_feats = torch.randn(1, n_classes, 512)
        self.clip_model = None
        super().__init__(cfg, ch, nc, verbose)

    def set_classes(self, text):
        """
        Encode the given class texts with CLIP and store them as the model's text features.

        The CLIP "ViT-B/32" model is loaded on first use (the `clip` package is installed via
        `check_requirements` if the import fails). The L2-normalised features are stored detached in
        `self.txt_feats`, and the final layer's `nc` is set to `len(text)`.

        Args:
            text (list): Class texts to tokenize.
        """
        try:
            import clip
        except ImportError:
            check_requirements("git+https://github.com/openai/CLIP.git")
            import clip

        if not getattr(self, "clip_model", None):
            self.clip_model, *_ = clip.load("ViT-B/32")
        clip_device = next(self.clip_model.parameters()).device
        tokens = clip.tokenize(text).to(clip_device)
        feats = self.clip_model.encode_text(tokens).to(dtype=torch.float32)
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
        self.txt_feats = feats.reshape(-1, len(text), feats.shape[-1]).detach()
        self.model[-1].nc = len(text)

    def init_criterion(self):
        """Not implemented for this model."""
        raise NotImplementedError

    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
        """
        Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
            augment (bool, optional): Accepted for interface compatibility; not used by this method.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        batch_size = len(x)
        txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
        if len(txt_feats) != batch_size:
            txt_feats = txt_feats.repeat(batch_size, 1, 1)
        ori_txt_feats = txt_feats.clone()  # untouched copy for the WorldDetect layer

        y, dt, embeddings = [], [], []
        for m in self.model:
            if m.f != -1:
                if isinstance(m.f, int):
                    x = y[m.f]
                else:
                    x = [x if j == -1 else y[j] for j in m.f]
            if profile:
                self._profile_one_layer(m, x, dt)

            if isinstance(m, C2fAttn):
                x = m(x, txt_feats)
            elif isinstance(m, WorldDetect):
                x = m(x, ori_txt_feats)
            elif isinstance(m, ImagePoolingAttn):
                txt_feats = m(x, txt_feats)  # updates the text features; x is left as is
            else:
                x = m(x)

            y.append(x if m.i in self.save else None)
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                pooled = nn.functional.adaptive_avg_pool2d(x, (1, 1))
                embeddings.append(pooled.squeeze(-1).squeeze(-1))
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x
|
|
|
class YOLOv10DetectionModel(DetectionModel):
    """Detection model that uses `v10DetectLoss` as its criterion."""

    def init_criterion(self):
        """Return the v10 detection loss bound to this model."""
        criterion = v10DetectLoss(self)
        return criterion
|
|
|
class Ensemble(nn.ModuleList):
    """Ensemble of models."""

    def __init__(self):
        """Initialize an ensemble of models."""
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        """
        Run every member model on `x` and concatenate their first outputs along dim 2.

        Args:
            x (torch.Tensor): Input passed unchanged to every member.
            augment (bool): Forwarded to each member as `augment`.
            profile (bool): Forwarded to each member as `profile`.
            visualize (bool): Forwarded to each member as `visualize`.

        Returns:
            (tuple): The concatenated output tensor and None.
        """
        # Pass the flags by keyword: members take (x, profile, visualize, augment, ...), which is a different
        # order from this method's signature, so positional forwarding would swap the flags.
        y = [module(x, augment=augment, profile=profile, visualize=visualize)[0] for module in self]
        y = torch.cat(y, 2)
        return y, None
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def temporary_modules(modules=None):
    """
    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

    This function can be used to change the module paths during runtime. It's useful when refactoring code,
    where you've moved a module from one location to another, but you still want to support the old import
    paths for backwards compatibility.

    Args:
        modules (dict, optional): A dictionary mapping old module paths to new module paths.

    Example:
        ```python
        with temporary_modules({'old.module.path': 'new.module.path'}):
            import old.module.path  # this will now import new.module.path
        ```

    Note:
        The changes are only in effect inside the context manager and are undone once the context manager exits:
        aliases that did not exist before are removed, and entries that already existed in `sys.modules` under an
        aliased name are restored to their previous module.
        Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
        applications or libraries. Use this function with caution.
    """
    import importlib
    import sys

    if not modules:
        modules = {}

    absent = object()  # marks "this name was not in sys.modules before we aliased it"
    previous = {}  # alias name -> prior sys.modules entry, recorded only for aliases actually installed
    try:
        for old, new in modules.items():
            target = importlib.import_module(new)  # may raise; nothing is recorded for this alias then
            previous[old] = sys.modules.get(old, absent)
            sys.modules[old] = target

        yield
    finally:
        # Undo only what was installed, putting back whatever was there before.
        for old, prior in previous.items():
            if prior is absent:
                sys.modules.pop(old, None)
            else:
                sys.modules[old] = prior
|
|
|
|
|
def torch_safe_load(weight):
    """
    This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
    it catches the error, logs a warning message, and attempts to install the missing module via the
    check_requirements() function. After installation, the function again attempts to load the model using torch.load().

    Args:
        weight (str): The file path of the PyTorch model.

    Returns:
        (tuple): The loaded checkpoint dict and the file path returned by `attempt_download_asset`.

    Raises:
        TypeError: If the missing module is named "models" (the checkpoint is then reported as a YOLOv5 model).
    """
    from ultralytics.utils.downloads import attempt_download_asset

    check_suffix(file=weight, suffix=".pt")
    file = attempt_download_asset(weight)

    # Old import paths that may be pickled inside a checkpoint, mapped to their current locations.
    legacy_aliases = {
        "ultralytics.yolo.utils": "ultralytics.utils",
        "ultralytics.yolo.v8": "ultralytics.models.yolo",
        "ultralytics.yolo.data": "ultralytics.data",
    }

    def _load():
        """Unpickle the checkpoint on CPU with the legacy module aliases in place."""
        # SECURITY NOTE: torch.load unpickles the file and can execute arbitrary code; only load trusted weights.
        with temporary_modules(legacy_aliases):
            return torch.load(file, map_location="cpu")

    try:
        ckpt = _load()

    except ModuleNotFoundError as e:  # e.name is the name of the missing module
        if e.name == "models":
            raise TypeError(
                emojis(
                    f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
                    f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
                    f"YOLOv8 at https://github.com/ultralytics/ultralytics."
                    f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
                    f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
                )
            ) from e
        LOGGER.warning(
            f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
            f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
            f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
            f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
        )
        check_requirements(e.name)  # install the missing module
        # Retry with the same legacy aliases so the retry sees the same import paths as the first attempt.
        ckpt = _load()

    if not isinstance(ckpt, dict):
        # The file holds a bare object rather than a checkpoint dict; wrap its `.model` attribute.
        LOGGER.warning(
            f"WARNING ⚠️ The file '{weight}' appears to be improperly saved or formatted. "
            f"For optimal results, use model.save('filename.pt') to correctly save YOLO models."
        )
        ckpt = {"model": ckpt.model}

    return ckpt, file
|
|
|
|
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
    """
    Load one or several checkpoints: weights=[a,b,c], weights=[a] or weights=a.

    Args:
        weights (str | list): A single weight path or a list of them.
        device: Target passed to `.to()` for every loaded model.
        inplace (bool): Value written to every submodule that has an `inplace` attribute.
        fuse (bool): Call `fuse()` on each model that provides it.

    Returns:
        The single loaded model when only one was given, otherwise the `Ensemble` holding all of them.
    """
    ensemble = Ensemble()
    weight_list = weights if isinstance(weights, list) else [weights]
    for item in weight_list:
        ckpt, path = torch_safe_load(item)
        args = None
        if "train_args" in ckpt:
            args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]}  # defaults overlaid with the training args
        model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model, EMA preferred

        # Attributes attached to the loaded model
        model.args = args
        model.pt_path = path
        model.task = guess_model_task(model)
        if not hasattr(model, "stride"):
            model.stride = torch.tensor([32.0])

        if fuse and hasattr(model, "fuse"):
            model = model.fuse()
        ensemble.append(model.eval())

    # Attribute updates applied to every submodule
    for module in ensemble.modules():
        if hasattr(module, "inplace"):
            module.inplace = inplace
        elif isinstance(module, nn.Upsample) and not hasattr(module, "recompute_scale_factor"):
            module.recompute_scale_factor = None

    if len(ensemble) == 1:
        return ensemble[-1]

    # Several models: expose shared attributes on the ensemble itself
    LOGGER.info(f"Ensemble created with {weights}\n")
    for attr in ("names", "nc", "yaml"):
        setattr(ensemble, attr, getattr(ensemble[0], attr))
    max_strides = torch.tensor([member.stride.max() for member in ensemble])
    ensemble.stride = ensemble[int(torch.argmax(max_strides))].stride  # stride of the member with the largest one
    assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
    return ensemble
|
|
|
|
|
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    """
    Load a single checkpoint and return the model together with the checkpoint dict.

    Args:
        weight (str): Path of the weight file.
        device: Target passed to `.to()` for the loaded model.
        inplace (bool): Value written to every submodule that has an `inplace` attribute.
        fuse (bool): Call `fuse()` on the model if it provides it.

    Returns:
        (tuple): The model in eval mode and the checkpoint dict.
    """
    ckpt, weight = torch_safe_load(weight)
    merged_args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # defaults overlaid with the training args
    model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model, EMA preferred

    # Attributes attached to the loaded model; args are restricted to the known config keys
    model.args = {key: value for key, value in merged_args.items() if key in DEFAULT_CFG_KEYS}
    model.pt_path = weight
    model.task = guess_model_task(model)
    if not hasattr(model, "stride"):
        model.stride = torch.tensor([32.0])

    if fuse and hasattr(model, "fuse"):
        model = model.fuse()
    model = model.eval()

    # Attribute updates applied to every submodule
    for module in model.modules():
        if hasattr(module, "inplace"):
            module.inplace = inplace
        elif isinstance(module, nn.Upsample) and not hasattr(module, "recompute_scale_factor"):
            module.recompute_scale_factor = None

    return model, ckpt
|
|
|
|
|
def parse_model(d, ch, verbose=True):
    """
    Parse a YOLO model.yaml dictionary into a PyTorch model.

    Args:
        d (dict): Model definition with "backbone" and "head" layer lists of (from, repeats, module, args) and
            optional "nc", "activation", "scales", "scale", "depth_multiple", "width_multiple", "kpt_shape".
        ch (int): Number of input channels.
        verbose (bool): Log one table row per layer.

    Returns:
        (tuple): The model as `nn.Sequential` and the sorted list of layer indices whose outputs must be kept.

    Note:
        String arguments in the YAML are resolved against this function's local variables (e.g. "nc",
        "kpt_shape") via `locals()`, so the local names used below are part of the behaviour.
    """
    import ast

    # Args
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    if scales:
        scale = d.get("scale")
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        # SECURITY NOTE: `act` comes from the model YAML and is evaluated as Python; only use trusted configs.
        Conv.default_act = eval(act)
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # built layers, indices to save, current output channels
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # resolve the module class by name
        for j, a in enumerate(args):
            if isinstance(a, str):
                # Strings naming a local are replaced by its value, others are parsed as literals; strings
                # for which literal_eval raises ValueError are left unchanged.
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth-scaled repeat count
        if m in {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            RepNCSPELAN4,
            ADown,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB,
        }:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # output channels equal to nc are left unscaled
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
                args[2] = int(
                    max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
                )

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB):
                args.insert(2, n)  # these modules take the repeat count as an argument
                n = 1
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in {HGStem, HGBlock}:
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:
            args.insert(1, [ch[x] for x in f])
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # the built layer
        t = str(m)[8:-2].replace("__main__.", "")  # module type string
        # Store the parameter count on the built layer itself, not on the module class, so every layer
        # reports its own count.
        m_.np = sum(x.numel() for x in m_.parameters())
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}")
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # outputs other layers read
        layers.append(m_)
        if i == 0:
            ch = []  # from here on ch holds per-layer output channels only
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
|
|
|
|
|
def yaml_model_load(path):
    """
    Load a YOLO model configuration from a YAML file.

    Legacy YOLOv5/YOLOv8 P6 file names (e.g. ``yolov8n6``) are renamed to the ``-p6`` form with a warning. Unless the
    path contains ``v10``, the scale letter is removed from the file name so that a scale-agnostic YAML
    (e.g. ``yolov8.yaml`` for ``yolov8n.yaml``) is looked up first; the original path is used as a fallback.

    Args:
        path (str | Path): Path to the model YAML file.

    Returns:
        (dict): The loaded YAML content, extended with a ``scale`` key (result of ``guess_model_scale``) and a
            ``yaml_file`` key (the possibly renamed input path as a string).
    """
    import re

    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    if "v10" not in str(path):
        # Strip the scale letter from the file name only. Running the substitution over the whole path string would
        # match the first digit+letter pair anywhere, so a parent directory such as 'exp2m/' would be rewritten
        # instead of the model file name.
        unified_name = re.sub(r"(\d+)([nsblmx])(.+)?$", r"\1\3", path.name)
        unified_path = str(path.parent / unified_name)
    else:
        # Paths containing 'v10' are looked up as given, scale letter included.
        unified_path = path
    # Prefer the scale-agnostic file; fall back to the exact path when the soft (hard=False) lookup yields nothing.
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d
|
|
|
|
|
def guess_model_scale(model_path):
    """
    Infer the scale character of a YOLO model from the name of its YAML file.

    The file stem is searched for ``yolov`` followed by one or more digits and a single scale letter
    (one of n, s, b, l, m, x); that letter is returned.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The matched scale character, or an empty string when the stem does not contain the pattern.
    """
    import re

    try:
        stem = Path(model_path).stem
        return re.search(r"yolov\d+([nsblmx])", stem).group(1)
    except AttributeError:
        # re.search() returned None (no scale in the name), so there is no .group to call.
        return ""
|
|
|
|
|
def guess_model_task(model):
    """
    Guess the task of a PyTorch model from its architecture or configuration.

    Sources are tried in this order and the first one that yields a task wins:
        1. A config dict: the module name of its last ``head`` entry.
        2. An ``nn.Module``: the ``args["task"]`` entry, then the ``yaml`` config, each looked up on the object
           itself and on up to two nested ``.model`` attributes, then the types of its sub-modules.
        3. A ``str`` or ``Path``: task markers in the file stem (``-seg``, ``-cls``, ``-pose``, ``-obb``) or in the
           path components.

    Args:
        model (nn.Module | dict | str | Path): PyTorch model, model configuration in YAML format, or a path.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose' or 'obb'). When nothing matches, a warning
            is logged and 'detect' is returned.
    """
    from operator import attrgetter

    def cfg2task(cfg):
        """Guess from YAML dictionary; returns None when the head module is not recognized."""
        # Second-to-last field of the last head entry, presumably the module name (verify against the YAML schema).
        m = cfg["head"][-1][-2].lower()
        if m in {"classify", "classifier", "cls", "fc"}:
            return "classify"
        # 'worlddetect' is included so the config path agrees with the isinstance() checks below.
        if m in {"detect", "v10detect", "worlddetect"}:
            return "detect"
        if m == "segment":
            return "segment"
        if m == "pose":
            return "pose"
        if m == "obb":
            return "obb"
        return None

    # Guess from model cfg
    if isinstance(model, dict):
        with contextlib.suppress(Exception):
            task = cfg2task(model)
            if task:  # an unrecognized head must fall through to the default, not return None
                return task

    # Guess from PyTorch model
    if isinstance(model, nn.Module):
        # attrgetter follows dotted attribute paths, replacing eval() on "model.args"-style strings.
        for attr_path in ("args", "model.args", "model.model.args"):
            with contextlib.suppress(Exception):
                return attrgetter(attr_path)(model)["task"]
        for attr_path in ("yaml", "model.yaml", "model.model.yaml"):
            with contextlib.suppress(Exception):
                task = cfg2task(attrgetter(attr_path)(model))
                if task:  # keep searching instead of returning None for an unknown head
                    return task

        # Order is kept from the original: specialised heads are tested before the generic detect heads.
        for m in model.modules():
            if isinstance(m, Segment):
                return "segment"
            elif isinstance(m, Classify):
                return "classify"
            elif isinstance(m, Pose):
                return "pose"
            elif isinstance(m, OBB):
                return "obb"
            elif isinstance(m, (Detect, WorldDetect, v10Detect)):
                return "detect"

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if "-seg" in model.stem or "segment" in model.parts:
            return "segment"
        elif "-cls" in model.stem or "classify" in model.parts:
            return "classify"
        elif "-pose" in model.stem or "pose" in model.parts:
            return "pose"
        elif "-obb" in model.stem or "obb" in model.parts:
            return "obb"
        elif "detect" in model.parts:
            return "detect"

    # Unable to determine task from model
    LOGGER.warning(
        "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
        "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
    )
    return "detect"
|
|