Spaces:
Running
Running
""" | |
Loss function implementations. | |
""" | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from kornia.geometry import warp_perspective | |
from ..misc.geometry_utils import keypoints_to_grid, get_dist_mask, get_common_line_mask | |
def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
    """Get loss functions and either static or dynamic weighting.

    Args:
        model_cfg: model config mapping. "weighting_policy" (default "static")
            is the global policy; each per-loss helper reads its own keys.
        device: device the loss modules are moved to; it is also forwarded to
            the heatmap helper for its class-weight tensor.

    Returns:
        Tuple ``(loss_func, loss_weight)`` of dicts. ``loss_func`` holds
        "junc_loss" and "heatmap_loss" (plus "descriptor_loss" when
        ``model_cfg["descriptor_loss_func"]`` is set); ``loss_weight`` holds
        the matching "w_junc", "w_heatmap" (and "w_desc") weights.

    Raises:
        ValueError: if the global weighting policy is not "static"/"dynamic".
    """
    w_policy = model_cfg.get("weighting_policy", "static")
    if w_policy not in ("static", "dynamic"):
        raise ValueError("[Error] Not supported weighting policy.")

    w_junc, junc_loss_func = get_junction_loss_and_weight(model_cfg, w_policy)
    w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight(
        model_cfg, w_policy, device
    )
    loss_func = {
        "junc_loss": junc_loss_func.to(device),
        "heatmap_loss": heatmap_loss_func.to(device),
    }
    loss_weight = {"w_junc": w_junc, "w_heatmap": w_heatmap}

    # The descriptor loss is optional: only built when a loss name is configured.
    if model_cfg.get("descriptor_loss_func") is not None:
        w_desc, descriptor_loss_func = get_descriptor_loss_and_weight(
            model_cfg, w_policy
        )
        loss_func["descriptor_loss"] = descriptor_loss_func.to(device)
        loss_weight["w_desc"] = w_desc
    return loss_func, loss_weight
def get_junction_loss_and_weight(model_cfg, global_w_policy):
    """Get the junction loss function and weight.

    Args:
        model_cfg: model config mapping. Reads the optional
            "junction_loss_cfg" dict (optional "policy" key), "w_junc",
            "junction_loss_func" (default "superpoint"), "grid_size" and
            "keep_border_valid".
        global_w_policy: policy used when "junction_loss_cfg" has no "policy".

    Returns:
        Tuple ``(w_junc, junc_loss_func)``: a float32 0-dim tensor (static)
        or an ``nn.Parameter`` (dynamic), and a ``JunctionDetectionLoss``.

    Raises:
        ValueError: unknown weighting policy or unsupported loss function.
    """
    junction_loss_cfg = model_cfg.get("junction_loss_cfg", {})
    policy = junction_loss_cfg.get("policy", global_w_policy)
    if policy not in ("static", "dynamic"):
        raise ValueError("[Error] Unknown weighting policy for junction loss weight.")
    init_weight = torch.tensor(model_cfg["w_junc"], dtype=torch.float32)
    if policy == "dynamic":
        # A learnable weight (trained together with the network).
        w_junc = nn.Parameter(init_weight, requires_grad=True)
    else:
        w_junc = init_weight

    loss_name = model_cfg.get("junction_loss_func", "superpoint")
    if loss_name != "superpoint":
        raise ValueError("[Error] Not supported junction loss function.")
    junc_loss_func = JunctionDetectionLoss(
        model_cfg["grid_size"], model_cfg["keep_border_valid"]
    )
    return w_junc, junc_loss_func
def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device):
    """Get the heatmap loss function and weight.

    Args:
        model_cfg: model config mapping. Reads the optional
            "heatmap_loss_cfg" dict (optional "policy" key), "w_heatmap",
            "heatmap_loss_func" (default "cross_entropy") and
            "w_heatmap_class" (default 1.0).
        global_w_policy: policy used when "heatmap_loss_cfg" has no "policy".
        device: device on which the class-weight tensor is created.

    Returns:
        Tuple ``(w_heatmap, heatmap_loss_func)``: a float32 0-dim tensor
        (static) or an ``nn.Parameter`` (dynamic), and a ``HeatmapLoss``.

    Raises:
        ValueError: unknown weighting policy or unsupported loss function.
    """
    heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {})
    # Get the heatmap loss weight
    w_policy = heatmap_loss_cfg.get("policy", global_w_policy)
    if w_policy == "static":
        w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32)
    elif w_policy == "dynamic":
        w_heatmap = nn.Parameter(
            torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32),
            requires_grad=True,
        )
    else:
        # The message used to say "junction" (copy-paste from the junction helper).
        raise ValueError("[Error] Unknown weighting policy for heatmap loss weight.")

    # Get the corresponding heatmap loss based on the config
    heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy")
    if heatmap_loss_name != "cross_entropy":
        raise ValueError("[Error] Not supported heatmap loss function.")
    # Per-class weights are always static: class 0 gets 1.0, class 1 gets
    # "w_heatmap_class". Built directly as float32 on the target device.
    heatmap_class_w = model_cfg.get("w_heatmap_class", 1.0)
    class_weight = torch.tensor(
        [1.0, heatmap_class_w], dtype=torch.float32, device=device
    )
    heatmap_loss_func = HeatmapLoss(class_weight=class_weight)
    return w_heatmap, heatmap_loss_func
def get_descriptor_loss_and_weight(model_cfg, global_w_policy):
    """Get the descriptor loss function and weight.

    Args:
        model_cfg: model config mapping. Reads the optional
            "descriptor_loss_cfg" dict (optional "policy"; "grid_size",
            "dist_threshold" and "margin" for the regular-sampling loss),
            "w_desc" and "descriptor_loss_func" (default "regular_sampling").
        global_w_policy: policy used when "descriptor_loss_cfg" has no "policy".

    Returns:
        Tuple ``(w_descriptor, descriptor_loss_func)``: a float32 0-dim tensor
        (static) or an ``nn.Parameter`` (dynamic), and a
        ``TripletDescriptorLoss``.

    Raises:
        ValueError: unknown weighting policy or unsupported loss function.
    """
    descriptor_loss_cfg = model_cfg.get("descriptor_loss_cfg", {})
    policy = descriptor_loss_cfg.get("policy", global_w_policy)
    if policy not in ("static", "dynamic"):
        raise ValueError("[Error] Unknown weighting policy for descriptor loss weight.")
    init_weight = torch.tensor(model_cfg["w_desc"], dtype=torch.float32)
    if policy == "dynamic":
        w_descriptor = nn.Parameter(init_weight, requires_grad=True)
    else:
        w_descriptor = init_weight

    loss_name = model_cfg.get("descriptor_loss_func", "regular_sampling")
    if loss_name != "regular_sampling":
        raise ValueError("[Error] Not supported descriptor loss function.")
    descriptor_loss_func = TripletDescriptorLoss(
        descriptor_loss_cfg["grid_size"],
        descriptor_loss_cfg["dist_threshold"],
        descriptor_loss_cfg["margin"],
    )
    return w_descriptor, descriptor_loss_func
def space_to_depth(input_tensor, grid_size):
    """Fold every ``grid_size`` x ``grid_size`` spatial patch into channels.

    Maps an (N, C, H, W) tensor to (N, C * grid_size**2, H // grid_size,
    W // grid_size). H and W must be divisible by ``grid_size``, otherwise the
    first ``view`` raises.

    The output channel index is ``(dy * grid_size + dx) * C + c``, where (dy,
    dx) is the offset inside the patch. This equals ``torch.nn.PixelUnshuffle``
    only when C == 1 (PixelUnshuffle uses ``c * grid_size**2 + dy * grid_size
    + dx``).
    """
    n, c, h, w = input_tensor.size()
    hc, wc = h // grid_size, w // grid_size
    patches = input_tensor.view(n, c, hc, grid_size, wc, grid_size)
    # Move the two intra-patch offset axes in front of the channel axis:
    # (n, grid_size, grid_size, c, hc, wc)
    patches = patches.permute(0, 3, 5, 1, 2, 4).contiguous()
    return patches.view(n, c * grid_size * grid_size, hc, wc)
def junction_detection_loss(
    junction_map, junc_predictions, valid_mask=None, grid_size=8, keep_border=True
):
    """Junction detection loss (cell-wise cross-entropy with a dust bin).

    Args:
        junction_map: full-resolution ground-truth junction map, 4-D
            (N, C, H, W) with H, W divisible by ``grid_size``; presumably a
            0/1 map with C == 1 (TODO confirm against the data loader).
        junc_predictions: logits whose class dim 1 must match the
            C * grid_size**2 + 1 label classes at (H // grid_size, W // grid_size).
        valid_mask: optional mask with the same shape as ``junction_map``
            (nonzero = valid). None means every pixel is valid.
        grid_size: side length of a cell.
        keep_border: if True, a cell counts when any of its mask entries is
            valid; otherwise only when at least grid_size**2 entries are valid.

    Returns:
        Scalar tensor: mean cross-entropy over valid cells (0/0 -> NaN when no
        cell is valid).
    """
    # Convert junc_map to channel tensor: one channel per in-cell pixel position.
    junc_map = space_to_depth(junction_map, grid_size)
    map_shape = junc_map.shape[-2:]
    batch_size = junc_map.shape[0]
    # Extra "dust bin" channel with value 1. Junction pixels are doubled to 2,
    # so (for a 0/1 map) any junction in a cell beats the dust bin, while an
    # empty cell (all 0) is labelled as dust bin.
    dust_bin_label = (
        torch.ones([batch_size, 1, map_shape[0], map_shape[1]])
        .to(junc_map.device)
        .to(torch.int)
    )
    junc_map = torch.cat([junc_map * 2, dust_bin_label], dim=1)
    # Noise in [0, 0.1) breaks ties at random when a cell holds several junctions.
    labels = torch.argmax(
        junc_map.to(torch.float)
        + torch.distributions.Uniform(0, 0.1)
        .sample(junc_map.shape)
        .to(junc_map.device),
        dim=1,
    )

    # Default mask must live on the input's device (it used to be created on
    # CPU, which broke the masked product below for CUDA inputs).
    if valid_mask is None:
        valid_mask = torch.ones(junction_map.shape, device=junction_map.device)
    valid_mask = space_to_depth(valid_mask, grid_size)
    valid_count = torch.sum(
        valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True
    )
    # Compute junction loss on the border patch or not
    if keep_border:
        valid_mask = valid_count > 0
    else:
        valid_mask = valid_count >= grid_size * grid_size

    # Per-cell classification loss, still in (N, Hc, Wc) layout.
    loss = F.cross_entropy(junc_predictions, labels.to(torch.long), reduction="none")
    # Weighted mean over the valid cells
    cell_weights = torch.squeeze(valid_mask.to(torch.float), dim=1)
    return torch.sum(loss * cell_weights) / torch.sum(cell_weights)
def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, class_weight=None):
    """Heatmap prediction loss: per-pixel cross-entropy averaged over valid pixels.

    Args:
        heatmap_gt: ground-truth class-index map; dim 1 is squeezed away, so
            presumably (N, 1, H, W) — TODO confirm.
        heatmap_pred: logits with the class dimension at dim 1.
        valid_mask: optional mask squeezed on dim 1 like ``heatmap_gt``
            (nonzero = valid). None now means every pixel is valid; it used to
            raise AttributeError despite being the default.
        class_weight: optional per-class weight tensor passed to cross-entropy.

    Returns:
        Scalar tensor: sum of (class-weighted) per-pixel losses over valid
        pixels divided by the number of valid pixels (0/0 -> NaN if none).
    """
    target = torch.squeeze(heatmap_gt.to(torch.long), dim=1)
    # Per-pixel loss (N, H, W); ``weight=None`` means unweighted.
    pixel_loss = F.cross_entropy(
        heatmap_pred, target, weight=class_weight, reduction="none"
    )
    if valid_mask is None:
        mask = torch.ones_like(pixel_loss)
    else:
        mask = torch.squeeze(valid_mask.to(torch.float32), dim=1)
    # Mean to a single scalar over all valid pixels of the whole batch.
    return torch.sum(pixel_loss * mask) / torch.sum(mask)
class JunctionDetectionLoss(nn.Module):
    """Junction detection loss.

    Thin ``nn.Module`` wrapper that stores the cell size and border policy and
    delegates to ``junction_detection_loss``.
    """

    def __init__(self, grid_size, keep_border):
        super().__init__()
        self.grid_size = grid_size
        self.keep_border = keep_border

    def forward(self, prediction, target, valid_mask=None):
        # Note the argument order: the module takes (prediction, target) while
        # the functional form takes the target junction map first.
        return junction_detection_loss(
            junction_map=target,
            junc_predictions=prediction,
            valid_mask=valid_mask,
            grid_size=self.grid_size,
            keep_border=self.keep_border,
        )
class HeatmapLoss(nn.Module):
    """Heatmap prediction loss.

    Thin ``nn.Module`` wrapper that stores the per-class weight tensor and
    delegates to ``heatmap_loss``.
    """

    def __init__(self, class_weight):
        super().__init__()
        # Stored as a plain attribute (not a parameter or buffer).
        self.class_weight = class_weight

    def forward(self, prediction, target, valid_mask=None):
        # The functional form takes the ground truth first.
        return heatmap_loss(
            target, prediction, valid_mask=valid_mask, class_weight=self.class_weight
        )
class RegularizationLoss(nn.Module):
    """Module for regularization loss.

    Sums every ``nn.Parameter`` entry of the loss-weight dict (the dynamic
    weights). Non-parameter weights contribute nothing, so the result is a
    zero scalar when no weight is learnable.
    """

    def __init__(self):
        super(RegularizationLoss, self).__init__()
        self.name = "regularization_loss"
        self.loss_init = torch.zeros([])

    def forward(self, loss_weights):
        # Place it on the same device as the junction weight. ``Tensor.to``
        # returns ``self.loss_init`` itself when no move is needed, so the sum
        # is built out-of-place; the old ``+=`` mutated the shared zero and
        # made the term grow across calls.
        loss = self.loss_init.to(loss_weights["w_junc"].device)
        for val in loss_weights.values():
            if isinstance(val, nn.Parameter):
                loss = loss + val
        return loss
def triplet_loss(
    desc_pred1,
    desc_pred2,
    points1,
    points2,
    line_indices,
    epoch,
    grid_size=8,
    dist_threshold=8,
    init_dist_threshold=64,
    margin=1,
):
    """Regular triplet loss for descriptor learning.

    Args:
        desc_pred1, desc_pred2: dense descriptor maps unpacked as
            (b, C, Hc, Wc); the image size is taken as (Hc, Wc) * grid_size.
        points1, points2: keypoints in the format expected by
            ``keypoints_to_grid`` / ``get_dist_mask``.
        line_indices: (b, n_points) tensor; nonzero entries mark valid points.
        epoch: current epoch; the negative-mining distance threshold is
            ``max(dist_threshold, 2 * init_dist_threshold // (epoch + 1))``.
        grid_size, dist_threshold, init_dist_threshold, margin: see above.

    Returns:
        Tuple ``(per_point_loss, grid1, grid2, valid_points)``. When no point
        is valid, ``per_point_loss`` is a 0-dim zero tensor (the function used
        to return that bare tensor, which broke callers indexing ``[0]``).
    """
    b_size, _, Hc, Wc = desc_pred1.size()
    img_size = (Hc * grid_size, Wc * grid_size)
    device = desc_pred1.device

    # Extract valid keypoints
    n_points = line_indices.size()[1]
    valid_points = line_indices.bool().flatten()
    n_correct_points = torch.sum(valid_points).item()

    # Convert the keypoints to a grid suitable for interpolation
    grid1 = keypoints_to_grid(points1, img_size)
    grid2 = keypoints_to_grid(points2, img_size)

    if n_correct_points == 0:
        zero = torch.tensor(0.0, dtype=torch.float, device=device)
        return zero, grid1, grid2, valid_points

    # Check which keypoints are too close to be matched.
    # dist_threshold is decreased at each epoch for easier training.
    dist_threshold = max(dist_threshold, 2 * init_dist_threshold // (epoch + 1))
    dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold)
    # Additionally ban negative mining along the same line
    common_line_mask = get_common_line_mask(line_indices, valid_points)
    dist_mask = dist_mask | common_line_mask

    # Extract and L2-normalize the descriptors of the valid points.
    # align_corners=False is the PyTorch default; stated explicitly.
    desc1 = (
        F.grid_sample(desc_pred1, grid1, align_corners=False)
        .permute(0, 2, 3, 1)
        .reshape(b_size * n_points, -1)[valid_points]
    )
    desc1 = F.normalize(desc1, dim=1)
    desc2 = (
        F.grid_sample(desc_pred2, grid2, align_corners=False)
        .permute(0, 2, 3, 1)
        .reshape(b_size * n_points, -1)[valid_points]
    )
    desc2 = F.normalize(desc2, dim=1)
    # Squared L2 distance between unit vectors: ||a - b||^2 = 2 - 2 a.b (in [0, 4]).
    desc_dists = 2 - 2 * (desc1 @ desc2.t())

    # Positive distance loss
    pos_dist = torch.diag(desc_dists)

    # Negative distance loss: hide positives and banned pairs behind the
    # maximum possible distance, then take the hardest negative per row/column.
    max_dist = torch.tensor(4.0, dtype=torch.float, device=device)
    diag_idx = torch.arange(n_correct_points, dtype=torch.long, device=device)
    desc_dists[diag_idx, diag_idx] = max_dist
    desc_dists[dist_mask] = max_dist
    neg_dist = torch.min(
        torch.min(desc_dists, dim=1)[0], torch.min(desc_dists, dim=0)[0]
    )

    triplet_loss = F.relu(margin + pos_dist - neg_dist)
    return triplet_loss, grid1, grid2, valid_points
class TripletDescriptorLoss(nn.Module):
    """Triplet descriptor loss on points regularly sampled along lines.

    Args:
        grid_size: forwarded to ``triplet_loss`` (image size = descriptor-map
            size * grid_size).
        dist_threshold: lower bound of the negative-mining distance threshold.
        margin: triplet margin.
        init_dist_threshold: initial distance threshold, decayed with the
            epoch by ``triplet_loss``. Defaults to 64, the previously
            hard-coded value.
    """

    def __init__(self, grid_size, dist_threshold, margin, init_dist_threshold=64):
        super(TripletDescriptorLoss, self).__init__()
        self.grid_size = grid_size
        self.init_dist_threshold = init_dist_threshold
        self.dist_threshold = dist_threshold
        self.margin = margin

    def forward(self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch):
        return self.descriptor_loss(
            desc_pred1, desc_pred2, points1, points2, line_indices, epoch
        )

    # The descriptor loss based on regularly sampled points along the lines
    def descriptor_loss(
        self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch
    ):
        """Mean per-point triplet loss."""
        result = triplet_loss(
            desc_pred1,
            desc_pred2,
            points1,
            points2,
            line_indices,
            epoch,
            self.grid_size,
            self.dist_threshold,
            self.init_dist_threshold,
            self.margin,
        )
        # triplet_loss normally returns a tuple whose first item is the
        # per-point loss, but it may return a bare 0-dim tensor when no point
        # is valid; indexing that with [0] would raise, so accept both.
        per_point_loss = result[0] if isinstance(result, tuple) else result
        return torch.mean(per_point_loss)
class TotalLoss(nn.Module):
    """Total loss summing junction, heatmap, descriptor and regularization losses.

    Each loss weight is resolved individually: an ``nn.Parameter`` weight
    ``w`` is a dynamic (learned) weight applied as ``exp(-w)``; any other
    weight is multiplied in as-is.
    """

    def __init__(self, loss_funcs, loss_weights, weighting_policy):
        super(TotalLoss, self).__init__()
        # Whether we need to compute the descriptor loss
        self.compute_descriptors = "descriptor_loss" in loss_funcs.keys()
        self.loss_funcs = loss_funcs
        self.loss_weights = loss_weights
        self.weighting_policy = weighting_policy
        # Always add regularization loss (it returns zero if no weight is an
        # nn.Parameter). NOTE: loss_funcs is stored by reference, so this also
        # adds "reg_loss" to the caller's dict. No .cuda(): RegularizationLoss
        # holds no parameters/buffers and already places its result on the
        # device of loss_weights["w_junc"].
        self.loss_funcs["reg_loss"] = RegularizationLoss()

    def _resolve_weight(self, key):
        """Return ``(effective_weight, logged_value)`` for ``loss_weights[key]``."""
        weight = self.loss_weights[key]
        if isinstance(weight, nn.Parameter):
            effective = torch.exp(-weight)
            # Log a plain float so the output dict does not keep the graph alive.
            return effective, effective.item()
        return weight, weight

    def forward(
        self, junc_pred, junc_target, heatmap_pred, heatmap_target, valid_mask=None
    ):
        """Detection only loss.

        Returns:
            Dict with "total_loss", "junc_loss" and "heatmap_loss"; under the
            dynamic policy also "reg_loss" and the effective "w_junc" /
            "w_heatmap" as floats.

        Raises:
            ValueError: if ``weighting_policy`` is neither "static" nor "dynamic".
        """
        # Compute the junction loss
        junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target, valid_mask)
        # Compute the heatmap loss
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            heatmap_pred, heatmap_target, valid_mask
        )
        # Compute the total loss.
        if self.weighting_policy == "dynamic":
            reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
            # Per-weight resolution, consistent with forward_descriptors: only
            # learnable weights go through exp(-w).
            w_junc, _ = self._resolve_weight("w_junc")
            w_heatmap, _ = self._resolve_weight("w_heatmap")
            total_loss = junc_loss * w_junc + heatmap_loss * w_heatmap + reg_loss
            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "reg_loss": reg_loss,
                "w_junc": float(w_junc),
                "w_heatmap": float(w_heatmap),
            }
        elif self.weighting_policy == "static":
            total_loss = (
                junc_loss * self.loss_weights["w_junc"]
                + heatmap_loss * self.loss_weights["w_heatmap"]
            )
            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
            }
        else:
            raise ValueError("[Error] Unknown weighting policy.")

    def forward_descriptors(
        self,
        junc_map_pred1,
        junc_map_pred2,
        junc_map_target1,
        junc_map_target2,
        heatmap_pred1,
        heatmap_pred2,
        heatmap_target1,
        heatmap_target2,
        line_points1,
        line_points2,
        line_indices,
        desc_pred1,
        desc_pred2,
        epoch,
        valid_mask1=None,
        valid_mask2=None,
    ):
        """Loss for detection + description.

        Both image batches are concatenated along dim 0 for the junction and
        heatmap losses. A missing valid mask (None) now means "all valid"; it
        previously made ``torch.cat`` fail.
        """
        # The same mask feeds both detection losses, so it shares the junction
        # target's shape.
        if valid_mask1 is None:
            valid_mask1 = torch.ones_like(junc_map_target1)
        if valid_mask2 is None:
            valid_mask2 = torch.ones_like(junc_map_target2)
        valid_mask = torch.cat([valid_mask1, valid_mask2], dim=0)

        # Compute junction loss
        junc_loss = self.loss_funcs["junc_loss"](
            torch.cat([junc_map_pred1, junc_map_pred2], dim=0),
            torch.cat([junc_map_target1, junc_map_target2], dim=0),
            valid_mask,
        )
        w_junc, w_junc_log = self._resolve_weight("w_junc")

        # Compute heatmap loss
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            torch.cat([heatmap_pred1, heatmap_pred2], dim=0),
            torch.cat([heatmap_target1, heatmap_target2], dim=0),
            valid_mask,
        )
        w_heatmap, w_heatmap_log = self._resolve_weight("w_heatmap")

        # Compute the descriptor loss
        descriptor_loss = self.loss_funcs["descriptor_loss"](
            desc_pred1, desc_pred2, line_points1, line_points2, line_indices, epoch
        )
        w_descriptor, w_descriptor_log = self._resolve_weight("w_desc")

        # Update the total loss
        total_loss = (
            junc_loss * w_junc
            + heatmap_loss * w_heatmap
            + descriptor_loss * w_descriptor
        )
        # Dynamic weights are logged as floats (the old isinstance check tested
        # exp(-param), which is never a Parameter, so tensors were logged).
        outputs = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "w_junc": w_junc_log,
            "w_heatmap": w_heatmap_log,
            "descriptor_loss": descriptor_loss,
            "w_desc": w_descriptor_log,
        }
        # Compute the regularization loss
        reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
        total_loss = total_loss + reg_loss
        outputs.update({"reg_loss": reg_loss, "total_loss": total_loss})
        return outputs