File size: 1,210 Bytes
88b0dcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
"""
@Date: 2021/08/12
@description:
"""
import torch
import torch.nn as nn
class LEDLoss(nn.Module):
    """Summed L1 loss between ground-truth depth and predicted depths.

    Two L1 terms are computed against the ground-truth depth (rescaled by the
    camera height): one for the predicted floor depth and one for the
    predicted ceiling depth, the latter additionally rescaled by ``gt['ratio']``
    (presumably the ceiling-to-floor height ratio — confirm against the
    dataset) so that it is comparable to the floor-referenced GT depth.
    """

    def __init__(self, camera_height=1.6):
        """
        Args:
            camera_height: assumed camera height used to rescale the
                normalized depths. Defaults to 1.6, the value previously
                hard-coded in ``forward`` (backward compatible).
        """
        super().__init__()
        self.camera_height = camera_height
        self.loss = nn.L1Loss()

    def forward(self, gt, dt):
        """Compute floor L1 loss + ceiling L1 loss.

        Args:
            gt: dict with tensors 'depth' (normalized GT depth) and 'ratio'.
            dt: dict with predicted tensors 'depth' (floor) and 'ceil_depth'.

        Returns:
            Scalar tensor: ``L1(gt_depth, floor) + L1(gt_depth, ceil)``.
        """
        camera_height = self.camera_height
        gt_depth = gt['depth'] * camera_height
        # Rescale the ceiling prediction by the GT ratio so it lives in the
        # same (floor-referenced) depth space as gt_depth.
        dt_ceil_depth = dt['ceil_depth'] * camera_height * gt['ratio']
        dt_floor_depth = dt['depth'] * camera_height
        ceil_loss = self.loss(gt_depth, dt_ceil_depth)
        floor_loss = self.loss(gt_depth, dt_floor_depth)
        return floor_loss + ceil_loss
if __name__ == '__main__':
    # Smoke test: feed one dataset sample through LEDLoss with a "perfect"
    # dummy prediction; the printed loss should be (near) zero.
    import numpy as np

    from dataset.mp3d_dataset import MP3DDataset

    mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode='train')
    # Idiomatic indexing instead of calling __getitem__ directly.
    gt = mp3d_dataset[0]
    # Prepend a batch axis of size 1 and convert to tensors.
    gt['depth'] = torch.from_numpy(gt['depth'][np.newaxis])
    gt['ratio'] = torch.from_numpy(gt['ratio'][np.newaxis])
    # 'ceil_depth' is chosen so that rescaling by gt['ratio'] inside the loss
    # reproduces the GT depth exactly.
    dummy_dt = {
        'depth': gt['depth'].clone(),
        'ceil_depth': gt['depth'] / gt['ratio'],
    }
    # dummy_dt['depth'][..., :20] *= 3  # uncomment to make predictions differ
    led_loss = LEDLoss()
    loss = led_loss(gt, dummy_dt)
    print(loss)
|