File size: 1,588 Bytes
88b0dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
""" 
@Date: 2021/08/12
@description: For HorizonNet, using latitudes to calculate loss.
"""
import torch
import torch.nn as nn
from utils.conversion import depth2xyz, xyz2lonlat


class BoundaryLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.L1Loss()

    def forward(self, gt, dt):
        gt_floor_xyz = depth2xyz(gt['depth'])
        gt_ceil_xyz = gt_floor_xyz.clone()
        gt_ceil_xyz[..., 1] = -gt['ratio']

        gt_floor_boundary = xyz2lonlat(gt_floor_xyz)[..., -1:]
        gt_ceil_boundary = xyz2lonlat(gt_ceil_xyz)[..., -1:]

        gt_boundary = torch.cat([gt_floor_boundary, gt_ceil_boundary], dim=-1).permute(0, 2, 1)
        dt_boundary = dt['boundary']

        loss = self.loss(gt_boundary, dt_boundary)
        return loss


if __name__ == '__main__':
    import numpy as np
    from dataset.mp3d_dataset import MP3DDataset

    mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode='train')
    gt = mp3d_dataset.__getitem__(0)

    gt['depth'] = torch.from_numpy(gt['depth'][np.newaxis])  # batch size is 1
    gt['ratio'] = torch.from_numpy(gt['ratio'][np.newaxis])  # batch size is 1

    dummy_dt = {
        'depth': gt['depth'].clone(),
        'boundary': torch.cat([
            xyz2lonlat(depth2xyz(gt['depth']))[..., -1:],
            xyz2lonlat(depth2xyz(gt['depth'], plan_y=-gt['ratio']))[..., -1:]
            ], dim=-1).permute(0, 2, 1)
    }
    # dummy_dt['boundary'][:, :, :20] /= 1.2  # some different

    boundary_loss = BoundaryLoss()
    loss = boundary_loss(gt, dummy_dt)
    print(loss)