Spaces:
Runtime error
Runtime error
File size: 2,835 Bytes
3a0062c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
"""
Implementation of a YOLO loss function similar to the one in the YOLOv3 paper.
As far as I can tell, the main difference is that CrossEntropy is used for the
class predictions instead of BinaryCrossEntropy.
"""
import random
import pytorch_lightning as pl
import torch
import torch.nn as nn
from .utils import intersection_over_union
class YoloLoss(pl.LightningModule):
    """YOLOv3-style detection loss for a single prediction scale.

    The returned loss is a weighted sum of four terms:

    * no-object loss: BCE-with-logits on the objectness logit of cells whose
      target objectness is 0;
    * object loss: MSE between sigmoid(objectness logit) and the detached IoU
      of the decoded predicted box with the target box, on cells whose target
      objectness is 1;
    * box loss: MSE between (sigmoid(x), sigmoid(y), raw w/h logits) and
      (x, y, log(w / anchor_w), log(h / anchor_h)) on object cells;
    * class loss: cross-entropy between the class logits and the integer class
      index on object cells.

    Cells whose target objectness is neither 0 nor 1 (e.g. -1) are ignored.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        # Relative weights of the four loss terms.
        self.lambda_class = 1
        self.lambda_noobj = 5
        self.lambda_obj = 1
        self.lambda_box = 1

    def forward(self, predictions, target, anchors):
        """Compute the weighted loss for one scale.

        Channel layout (from the indexing below):
        ``predictions[..., 0]`` objectness logit, ``1:3`` x/y logits,
        ``3:5`` w/h logits, ``5:`` class logits;
        ``target[..., 0]`` objectness (1 = object, 0 = no object, other values
        ignored), ``1:5`` box, ``5`` class index.
        Presumably both are (batch, num_anchors, S, S, C) — confirm with caller.

        Args:
            predictions: raw network output for this scale (not modified).
            target: encoded ground truth for this scale (not modified).
            anchors: anchor (w, h) pairs for this scale, reshaped to
                ``(1, num_anchors, 1, 1, 2)`` for broadcasting.

        Returns:
            Scalar tensor with the weighted sum of the loss terms. When there
            are no object cells, only the no-object term contributes (the other
            terms would otherwise be means over zero elements, i.e. NaN).
        """
        obj = target[..., 0] == 1
        noobj = target[..., 0] == 0

        # ---- no-object loss ----
        noobj_logits = predictions[..., 0:1][noobj]
        if noobj_logits.numel() > 0:
            no_object_loss = self.bce(noobj_logits, target[..., 0:1][noobj])
        else:
            # Mean over zero elements would be NaN; the sum of an empty tensor
            # is 0 and stays attached to the autograd graph.
            no_object_loss = noobj_logits.sum()

        if not obj.any():
            # No objects at this scale: object/box/class terms are undefined.
            return self.lambda_noobj * no_object_loss

        # Broadcast anchors against the (w, h) channels.
        anchors = anchors.reshape(1, -1, 1, 1, 2)

        # Build new tensors instead of assigning into `predictions`/`target`,
        # so the caller's tensors (and the autograd graph) are left untouched.
        xy_pred = self.sigmoid(predictions[..., 1:3])
        wh_logits = predictions[..., 3:5]

        # ---- object loss ----
        # Decoded boxes: sigmoid offsets for x/y, exp(logit) * anchor for w/h.
        box_preds = torch.cat([xy_pred, torch.exp(wh_logits) * anchors], dim=-1)
        # IoU is detached so it acts as a fixed regression target.
        ious = intersection_over_union(
            box_preds[obj], target[..., 1:5][obj]
        ).detach()
        object_loss = self.mse(
            self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj]
        )

        # ---- box coordinate loss ----
        # Targets for w/h are moved into the same log-anchor space as the raw
        # logits; 1e-16 guards against log(0). Cast back to target's dtype to
        # match what an in-place assignment into `target` would have produced.
        wh_target = torch.log(1e-16 + target[..., 3:5] / anchors).to(target.dtype)
        box_pred_terms = torch.cat([xy_pred, wh_logits], dim=-1)
        box_target_terms = torch.cat([target[..., 1:3], wh_target], dim=-1)
        box_loss = self.mse(box_pred_terms[obj], box_target_terms[obj])

        # ---- class loss ----
        class_loss = self.entropy(
            predictions[..., 5:][obj],
            target[..., 5][obj].long(),
        )

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )
|