import torch
import torch.nn as nn
import torch.nn.functional as F

import spiga.models.gnn.pose_proj as pproj
from spiga.models.cnn.cnn_multitask import MultitaskCNN
from spiga.models.gnn.step_regressor import StepRegressor, RelativePositionEncoder


class SPIGA(nn.Module):
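    """SPIGA: cascaded landmark regressor.

    A MultitaskCNN backbone provides visual feature maps and a 6-DoF head-pose
    estimate; a rigid 3D landmark model is projected with that pose to obtain
    initial landmark coordinates, which ``steps`` cascaded GAT regressors
    (StepRegressor) refine using per-landmark visual and relative-shape features.
    """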
    def __init__(self, num_landmarks=98, num_edges=15, steps=3, **kwargs):

        super(SPIGA, self).__init__()

        # Model parameters
        self.steps = steps          # Cascaded regressors
        self.embedded_dim = 512     # GAT input channel
        self.nstack = 4             # Number of stacked GATs per step
        self.kwindow = 7            # Output cropped window dimension (kernel)
        self.swindow = 0.25         # Scale of the cropped window at the first step (default 25% w.r.t. the input feature map)
        self.offset_ratio = [self.swindow/(2**step)/2 for step in range(self.steps)]

        # CNN parameters
        self.num_landmarks = num_landmarks
        self.num_edges = num_edges

        # Initialize backbone
        self.visual_cnn = MultitaskCNN(num_landmarks=self.num_landmarks, num_edges=self.num_edges)
        # Feature dimensions
        self.img_res = self.visual_cnn.img_res
        self.visual_res = self.visual_cnn.out_res
        self.visual_dim = self.visual_cnn.ch_dim

        # Initialize Pose head
        self.channels_pose = 6
        self.pose_fc = nn.Linear(self.visual_cnn.ch_dim, self.channels_pose)

        # Initialize feature extractors:
        # Relative positional encoder
        shape_dim = 2 * (self.num_landmarks - 1)
        shape_encoder = []
        for step in range(self.steps):
            shape_encoder.append(RelativePositionEncoder(shape_dim, self.embedded_dim, [256, 256]))
        self.shape_encoder = nn.ModuleList(shape_encoder)
        # Diagonal mask used to compute relative positions
        diagonal_mask = (torch.ones(self.num_landmarks, self.num_landmarks) - torch.eye(self.num_landmarks)).type(torch.bool)
        self.diagonal_mask = nn.parameter.Parameter(diagonal_mask, requires_grad=False)

        # Visual feature extractor
        conv_window = []
        theta_S = []
        for step in range(self.steps):
            # S matrix per step
            WH = self.visual_res                                  # Width/height of the feature map
            Wout = self.swindow / (2 ** step) * WH                # Width/height of the window
            K = self.kwindow                                      # Kernel or resolution of the window
            scale = K / WH * (Wout - 1) / (K - 1)                 # Scale of the affine transformation
            # Rescale matrix S
            theta_S_stp = torch.tensor([[scale, 0], [0, scale]])
            theta_S.append(nn.parameter.Parameter(theta_S_stp, requires_grad=False))

            # Convolution that embeds each cropped KxK window into the embedded dimension (B*LxCx1x1)
            conv_window.append(nn.Conv2d(self.visual_dim, self.embedded_dim, self.kwindow))

        self.theta_S = nn.ParameterList(theta_S)
        self.conv_window = nn.ModuleList(conv_window)

        # Initialize GAT modules
        self.gcn = nn.ModuleList([StepRegressor(self.embedded_dim, 256, self.nstack) for i in range(self.steps)])

    def forward(self, data):
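        """Run the backbone once, then refine the projected landmarks over
        ``self.steps`` cascaded GAT steps.

        Returns the backbone feature dict extended with 'Landmarks' (one BxLx2
        tensor per step) and 'GATProb'.
        """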
        # Backbone forward: visual features and initial point projections
        pts_proj, features = self.backbone_forward(data)
        # Visual field
        visual_field = features['VisualField'][-1]

        # Initialized once and shared across the cascaded steps
        gat_prob = []
        features['Landmarks'] = []
        for step in range(self.steps):
            # Features generation
            embedded_ft = self.extract_embedded(pts_proj, visual_field, step)

            # GAT inference
            offset, gat_prob = self.gcn[step](embedded_ft, gat_prob)
            offset = F.hardtanh(offset)

            # Update coordinates
            pts_proj = pts_proj + self.offset_ratio[step] * offset
            features['Landmarks'].append(pts_proj.clone())

        features['GATProb'] = gat_prob
        return features

    def backbone_forward(self, data):
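        """Run the hourglass CNN, regress the 6-channel head pose (3 Euler
        angles + 3 translation components) and project the 3D model into the
        feature map to obtain the initial landmark estimates, normalized by
        the feature-map resolution.
        """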
        # Inputs: Image and model3D
        imgs = data[0]
        model3d = data[1]
        cam_matrix = data[2]

        # HourGlass Forward
        features = self.visual_cnn(imgs)

        # Head pose estimation
        pose_raw = features['HGcore'][-1]
        B, L, _, _ = pose_raw.shape
        pose = pose_raw.reshape(B, L)
        pose = self.pose_fc(pose)
        features['Pose'] = pose.clone()

        # Project model 3D
        euler = pose[:, 0:3]
        trl = pose[:, 3:]
        rot = pproj.euler_to_rotation_matrix(euler)
        pts_proj = pproj.projectPoints(model3d, rot, trl, cam_matrix)
        pts_proj = pts_proj / self.visual_res

        return pts_proj, features

    def extract_embedded(self, pts_proj, receptive_field, step):
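        """Per-landmark embedding: cropped visual features plus the encoded
        relative-shape features, combined by addition.
        """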
        # Visual features
        visual_ft = self.extract_visual_embedded(pts_proj, receptive_field, step)
        # Shape features
        shape_ft = self.calculate_distances(pts_proj)
        shape_ft = self.shape_encoder[step](shape_ft)
        # Addition
        embedded_ft = visual_ft + shape_ft
        return embedded_ft

    def extract_visual_embedded(self, pts_proj, receptive_field, step):
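        """Crop a KxK window (K = self.kwindow) around each projected landmark
        via affine_grid/grid_sample and embed it into a BxLxC feature through
        the step-specific convolution.
        """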
        # Affine matrix generation
        B, L, _ = pts_proj.shape  # Pts_proj range:[0,1]
        centers = pts_proj + 0.5 / self.visual_res  # BxLx2
        centers = centers.reshape(B * L, 2)  # B*Lx2
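        # (-1 + 2 * center) maps [0, 1] feature-map coordinates to the [-1, 1] grid used by affine_grid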
        theta_trl = (-1 + centers * 2).unsqueeze(-1)  # B*Lx2x1
        theta_s = self.theta_S[step]  # 2x2
        theta_s = theta_s.repeat(B * L, 1, 1)  # B*Lx2x2
        theta = torch.cat((theta_s, theta_trl), -1)  # B*Lx2x3

        # Generate crop grid
        B, C, _, _ = receptive_field.shape
        grid = torch.nn.functional.affine_grid(theta, (B * L, C, self.kwindow, self.kwindow))
        grid = grid.reshape(B, L, self.kwindow, self.kwindow, 2)
        grid = grid.reshape(B, L, self.kwindow * self.kwindow, 2)

        # Crop windows
        crops = torch.nn.functional.grid_sample(receptive_field, grid, padding_mode="border")  # BxCxLxK*K
        crops = crops.transpose(1, 2)  # BxLxCxK*K
        crops = crops.reshape(B * L, C, self.kwindow, self.kwindow)

        # Flatten features
        visual_ft = self.conv_window[step](crops)
        _, Cout, _, _ = visual_ft.shape
        visual_ft = visual_ft.reshape(B, L, Cout)

        return visual_ft

    def calculate_distances(self, pts_proj):
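        """Return, for each landmark, its 2D offsets to every other landmark
        (self-distances are removed by the diagonal mask), flattened to
        B x L x 2*(L-1).
        """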
        B, L, _ = pts_proj.shape    # BxLx2
        pts_a = pts_proj.unsqueeze(-2).repeat(1, 1, L, 1)
        pts_b = pts_a.transpose(1, 2)
        dist = pts_a - pts_b
        dist_wo_self = dist[:, self.diagonal_mask, :].reshape(B, L, -1)
        return dist_wo_self
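

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The tensor shapes below are assumptions inferred from this file: `imgs`
# should match the backbone input resolution (model.img_res), `model3d` holds
# per-sample 3D landmark coordinates, and `cam_matrix` is a 3x3 camera
# intrinsics matrix; check MultitaskCNN and pose_proj for the exact formats.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    model = SPIGA(num_landmarks=98, num_edges=15, steps=3)
    model.eval()

    batch = 2
    res = model.img_res                                          # backbone input resolution
    imgs = torch.randn(batch, 3, res, res)                       # assumed RGB input layout
    model3d = torch.randn(batch, model.num_landmarks, 3)         # assumed rigid 3D landmark model
    cam_matrix = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)   # assumed camera intrinsics

    with torch.no_grad():
        features = model([imgs, model3d, cam_matrix])

    # One refined landmark tensor per cascade step, in feature-map-normalized coordinates.
    for step, pts in enumerate(features['Landmarks']):
        print(step, tuple(pts.shape))                            # expected: (batch, num_landmarks, 2)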