Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import spiga.models.gnn.pose_proj as pproj | |
from spiga.models.cnn.cnn_multitask import MultitaskCNN | |
from spiga.models.gnn.step_regressor import StepRegressor, RelativePositionEncoder | |
class SPIGA(nn.Module): | |
def __init__(self, num_landmarks=98, num_edges=15, steps=3, **kwargs): | |
super(SPIGA, self).__init__() | |
# Model parameters | |
self.steps = steps # Cascaded regressors | |
self.embedded_dim = 512 # GAT input channel | |
self.nstack = 4 # Number of stacked GATs per step | |
self.kwindow = 7 # Output cropped window dimension (kernel) | |
self.swindow = 0.25 # Scale of the cropped window at first step (Dft. 25% w.r.t the input featuremap) | |
self.offset_ratio = [self.swindow/(2**step)/2 for step in range(self.steps)] | |
# CNN parameters | |
self.num_landmarks = num_landmarks | |
self.num_edges = num_edges | |
# Initialize backbone | |
self.visual_cnn = MultitaskCNN(num_landmarks=self.num_landmarks, num_edges=self.num_edges) | |
# Features dimensions | |
self.img_res = self.visual_cnn.img_res | |
self.visual_res = self.visual_cnn.out_res | |
self.visual_dim = self.visual_cnn.ch_dim | |
# Initialize Pose head | |
self.channels_pose = 6 | |
self.pose_fc = nn.Linear(self.visual_cnn.ch_dim, self.channels_pose) | |
# Initialize feature extractors: | |
# Relative positional encoder | |
shape_dim = 2 * (self.num_landmarks - 1) | |
shape_encoder = [] | |
for step in range(self.steps): | |
shape_encoder.append(RelativePositionEncoder(shape_dim, self.embedded_dim, [256, 256])) | |
self.shape_encoder = nn.ModuleList(shape_encoder) | |
# Diagonal mask used to compute relative positions | |
diagonal_mask = (torch.ones(self.num_landmarks, self.num_landmarks) - torch.eye(self.num_landmarks)).type(torch.bool) | |
self.diagonal_mask = nn.parameter.Parameter(diagonal_mask, requires_grad=False) | |
# Visual feature extractor | |
conv_window = [] | |
theta_S = [] | |
for step in range(self.steps): | |
# S matrix per step | |
WH = self.visual_res # Width/height of ftmap | |
Wout = self.swindow / (2 ** step) * WH # Width/height of the window | |
K = self.kwindow # Kernel or resolution of the window | |
scale = K / WH * (Wout - 1) / (K - 1) # Scale of the affine transformation | |
# Rescale matrix S | |
theta_S_stp = torch.tensor([[scale, 0], [0, scale]]) | |
theta_S.append(nn.parameter.Parameter(theta_S_stp, requires_grad=False)) | |
# Convolutional to embedded to BxLxCx1x1 | |
conv_window.append(nn.Conv2d(self.visual_dim, self.embedded_dim, self.kwindow)) | |
self.theta_S = nn.ParameterList(theta_S) | |
self.conv_window = nn.ModuleList(conv_window) | |
# Initialize GAT modules | |
self.gcn = nn.ModuleList([StepRegressor(self.embedded_dim, 256, self.nstack) for i in range(self.steps)]) | |
def forward(self, data): | |
# Inputs: Visual features and points projections | |
pts_proj, features = self.backbone_forward(data) | |
# Visual field | |
visual_field = features['VisualField'][-1] | |
# Params compute only once | |
gat_prob = [] | |
features['Landmarks'] = [] | |
for step in range(self.steps): | |
# Features generation | |
embedded_ft = self.extract_embedded(pts_proj, visual_field, step) | |
# GAT inference | |
offset, gat_prob = self.gcn[step](embedded_ft, gat_prob) | |
offset = F.hardtanh(offset) | |
# Update coordinates | |
pts_proj = pts_proj + self.offset_ratio[step] * offset | |
features['Landmarks'].append(pts_proj.clone()) | |
features['GATProb'] = gat_prob | |
return features | |
def backbone_forward(self, data): | |
# Inputs: Image and model3D | |
imgs = data[0] | |
model3d = data[1] | |
cam_matrix = data[2] | |
# HourGlass Forward | |
features = self.visual_cnn(imgs) | |
# Head pose estimation | |
pose_raw = features['HGcore'][-1] | |
B, L, _, _ = pose_raw.shape | |
pose = pose_raw.reshape(B, L) | |
pose = self.pose_fc(pose) | |
features['Pose'] = pose.clone() | |
# Project model 3D | |
euler = pose[:, 0:3] | |
trl = pose[:, 3:] | |
rot = pproj.euler_to_rotation_matrix(euler) | |
pts_proj = pproj.projectPoints(model3d, rot, trl, cam_matrix) | |
pts_proj = pts_proj / self.visual_res | |
return pts_proj, features | |
def extract_embedded(self, pts_proj, receptive_field, step): | |
# Visual features | |
visual_ft = self.extract_visual_embedded(pts_proj, receptive_field, step) | |
# Shape features | |
shape_ft = self.calculate_distances(pts_proj) | |
shape_ft = self.shape_encoder[step](shape_ft) | |
# Addition | |
embedded_ft = visual_ft + shape_ft | |
return embedded_ft | |
def extract_visual_embedded(self, pts_proj, receptive_field, step): | |
# Affine matrix generation | |
B, L, _ = pts_proj.shape # Pts_proj range:[0,1] | |
centers = pts_proj + 0.5 / self.visual_res # BxLx2 | |
centers = centers.reshape(B * L, 2) # B*Lx2 | |
theta_trl = (-1 + centers * 2).unsqueeze(-1) # BxLx2x1 | |
theta_s = self.theta_S[step] # 2x2 | |
theta_s = theta_s.repeat(B * L, 1, 1) # B*Lx2x2 | |
theta = torch.cat((theta_s, theta_trl), -1) # B*Lx2x3 | |
# Generate crop grid | |
B, C, _, _ = receptive_field.shape | |
grid = torch.nn.functional.affine_grid(theta, (B * L, C, self.kwindow, self.kwindow)) | |
grid = grid.reshape(B, L, self.kwindow, self.kwindow, 2) | |
grid = grid.reshape(B, L, self.kwindow * self.kwindow, 2) | |
# Crop windows | |
crops = torch.nn.functional.grid_sample(receptive_field, grid, padding_mode="border") # BxCxLxK*K | |
crops = crops.transpose(1, 2) # BxLxCxK*K | |
crops = crops.reshape(B * L, C, self.kwindow, self.kwindow) | |
# Flatten features | |
visual_ft = self.conv_window[step](crops) | |
_, Cout, _, _ = visual_ft.shape | |
visual_ft = visual_ft.reshape(B, L, Cout) | |
return visual_ft | |
def calculate_distances(self, pts_proj): | |
B, L, _ = pts_proj.shape # BxLx2 | |
pts_a = pts_proj.unsqueeze(-2).repeat(1, 1, L, 1) | |
pts_b = pts_a.transpose(1, 2) | |
dist = pts_a - pts_b | |
dist_wo_self = dist[:, self.diagonal_mask, :].reshape(B, L, -1) | |
return dist_wo_self | |