"""
@Date: 2021/08/12
@description: LED loss: L1 distance between the ground-truth floor depth and the
predicted floor/ceiling depths, both rescaled to the same metric floor depth.
"""
import torch
import torch.nn as nn


class LEDLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.L1Loss()

    def forward(self, gt, dt):
        # Depths are stored normalized by camera height; convert back to metres
        # with the assumed camera height of 1.6 m.
        camera_height = 1.6
        gt_depth = gt['depth'] * camera_height

        # Rescale the predicted ceiling depth by the ground-truth ratio so that
        # it is compared against the same (floor) ground-truth depth.
        dt_ceil_depth = dt['ceil_depth'] * camera_height * gt['ratio']
        dt_floor_depth = dt['depth'] * camera_height

        ceil_loss = self.loss(gt_depth, dt_ceil_depth)
        floor_loss = self.loss(gt_depth, dt_floor_depth)

        loss = floor_loss + ceil_loss
        return loss


if __name__ == '__main__':
    import numpy as np
    from dataset.mp3d_dataset import MP3DDataset

    mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode='train')
    gt = mp3d_dataset[0]

    # Add a batch dimension and convert the numpy arrays to tensors.
    gt['depth'] = torch.from_numpy(gt['depth'][np.newaxis])
    gt['ratio'] = torch.from_numpy(gt['ratio'][np.newaxis])

    # Build a "perfect" prediction from the ground truth; the loss should be 0.
    dummy_dt = {
        'depth': gt['depth'].clone(),
        'ceil_depth': gt['depth'] / gt['ratio']
    }

    led_loss = LEDLoss()
    loss = led_loss(gt, dummy_dt)
    print(loss)
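
    # A minimal synthetic sanity check (a sketch: the [1, 1024] depth and [1, 1]
    # ratio shapes are assumptions, not taken from MP3DDataset). A random,
    # imperfect prediction should give a positive loss.
    rand_gt = {
        'depth': torch.rand(1, 1024),
        'ratio': torch.rand(1, 1) + 0.5,
    }
    rand_dt = {
        'depth': torch.rand(1, 1024),
        'ceil_depth': torch.rand(1, 1024),
    }
    print(led_loss(rand_gt, rand_dt))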