Spaces:
Runtime error
Runtime error
File size: 12,126 Bytes
29a229f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 |
"""
Code adapted from: https://github.com/akanazawa/hmr/blob/master/src/benchmark/eval_util.py
"""
import torch
import numpy as np
from typing import Optional, Dict, List, Tuple
def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
    """
    Computes a similarity transform (sR, t) in a batched way that takes
    a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3),
    where R is a 3x3 rotation matrix, t 3x1 translation, s scale.
    i.e. solves the orthogonal Procrustes problem.

    Args:
        S1 (torch.Tensor): First set of points of shape (B, N, 3).
        S2 (torch.Tensor): Second set of points of shape (B, N, 3).
    Returns:
        (torch.Tensor): The first set of points after applying the similarity
            transformation, of shape (B, N, 3) and with the dtype of the inputs.
    """
    batch_size = S1.shape[0]

    # Work with points stored as columns: (B, 3, N).
    S1 = S1.permute(0, 2, 1)
    S2 = S2.permute(0, 2, 1)

    # 1. Remove mean.
    mu1 = S1.mean(dim=2, keepdim=True)
    mu2 = S2.mean(dim=2, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = (X1 ** 2).sum(dim=(1, 2))

    # 3. The outer product of X1 and X2.
    K = torch.matmul(X1, X2.permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
    U, s, V = torch.svd(K)
    Vh = V.permute(0, 2, 1)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    # Z must share U's dtype (not only its device); otherwise the matmul below
    # fails for non-default dtypes such as float64.
    Z = torch.eye(U.shape[1], dtype=U.dtype, device=U.device).unsqueeze(0).repeat(batch_size, 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U, Vh)))

    # Construct R.
    R = torch.matmul(torch.matmul(V, Z), U.permute(0, 2, 1))

    # 5. Recover scale.
    trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1)
    scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1)

    # 6. Recover translation.
    t = mu2 - scale * torch.matmul(R, mu1)

    # 7. Apply the transform to S1 and return it in the original (B, N, 3) layout.
    S1_hat = scale * torch.matmul(R, S1) + t
    return S1_hat.permute(0, 2, 1)
def reconstruction_error(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
    """
    Computes the mean Euclidean distance of 2 sets of points S1, S2 after performing Procrustes alignment.

    Args:
        S1 (torch.Tensor): First set of points of shape (B, N, 3).
        S2 (torch.Tensor): Second set of points of shape (B, N, 3).
    Returns:
        (torch.Tensor): Reconstruction error; the per-point distances are
            averaged over the last remaining dimension, leaving one value per
            batch element. The result is a tensor (not a NumPy array), so
            callers convert it themselves if needed.
    """
    # Align S1 onto S2 with the optimal similarity transform first.
    S1_hat = compute_similarity_transform(S1, S2)
    # Per-point Euclidean distance, then mean over the points.
    re = torch.sqrt(((S1_hat - S2) ** 2).sum(dim=-1)).mean(dim=-1)
    return re
def eval_pose(pred_joints, gt_joints) -> Tuple[np.array, np.array]:
    """
    Compute joint errors before and after Procrustes alignment.

    Both errors are multiplied by 1000 before being returned, i.e. they are
    reported in mm assuming the inputs are expressed in metres.

    Args:
        pred_joints (torch.Tensor): Predicted 3D joints of shape (B, N, 3).
        gt_joints (torch.Tensor): Ground truth 3D joints of shape (B, N, 3).
    Returns:
        Tuple[np.array, np.array]: Joint errors in mm before and after alignment.
    """
    # Absolute error (MPJPE): per-joint Euclidean distance averaged over joints.
    squared_diff = (pred_joints - gt_joints) ** 2
    joint_distances = torch.sqrt(squared_diff.sum(dim=-1))
    mpjpe = joint_distances.mean(dim=-1).cpu().numpy()

    # Reconstruction error: the same quantity measured after Procrustes alignment.
    aligned_error = reconstruction_error(pred_joints, gt_joints)
    r_error = aligned_error.cpu().numpy()

    return 1000 * mpjpe, 1000 * r_error
class Evaluator:
    """
    Accumulates per-sample pose metrics over successive batches.

    One NumPy array of length `dataset_length` is allocated per requested
    metric and stored as an attribute named after the metric; `__call__`
    fills the next `batch_size` entries and advances `counter`.
    """

    # Metrics produced by `eval_pose`, which scales its errors by 1000 (mm).
    _MM_METRICS = ('mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re', 'opt_mpjpe', 'opt_re')

    def __init__(self,
                 dataset_length: int,
                 keypoint_list: List,
                 pelvis_ind: int,
                 metrics: Optional[List] = None,
                 pck_thresholds: Optional[List] = None):
        """
        Class used for evaluating trained models on different 3D pose datasets.

        Args:
            dataset_length (int): Total dataset length.
            keypoint_list [List]: List of keypoints used for evaluation.
            pelvis_ind (int): Index of pelvis keypoint; used for aligning the predictions and ground truth.
            metrics [List]: List of evaluation metrics to record. Defaults to
                ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re'].
            pck_thresholds [List]: If given, an `EvaluatorPCK` with these
                thresholds is run alongside the metrics above.
        """
        # A fresh list per instance: a list default in the signature would be
        # shared (and mutable) across every Evaluator.
        if metrics is None:
            metrics = ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re']
        self.dataset_length = dataset_length
        self.keypoint_list = keypoint_list
        self.pelvis_ind = pelvis_ind
        self.metrics = list(metrics)
        for metric in self.metrics:
            setattr(self, metric, np.zeros((dataset_length,)))
        self.counter = 0
        if pck_thresholds is None:
            self.pck_evaluator = None
        else:
            self.pck_evaluator = EvaluatorPCK(pck_thresholds)

    def log(self):
        """
        Print current evaluation metrics
        """
        if self.counter == 0:
            print('Evaluation has not started')
            return
        print(f'{self.counter} / {self.dataset_length} samples')
        if self.pck_evaluator is not None:
            self.pck_evaluator.log()
        for metric in self.metrics:
            unit = 'mm' if metric in self._MM_METRICS else ''
            print(f'{metric}: {getattr(self, metric)[:self.counter].mean()} {unit}')
        print('***')

    def get_metrics_dict(self) -> Dict:
        """
        Returns:
            Dict: Dictionary of evaluation metrics, each averaged over the
                samples recorded so far (plus the PCK metrics, if enabled).
        """
        d1 = {metric: getattr(self, metric)[:self.counter].mean() for metric in self.metrics}
        if self.pck_evaluator is not None:
            d1.update(self.pck_evaluator.get_metrics_dict())
        return d1

    @staticmethod
    def _keypoint_2d_loss(output: Dict, batch: Dict, num_samples: int) -> np.ndarray:
        """Confidence-weighted squared 2D keypoint error, averaged over keypoints."""
        pred_keypoints_2d = output['pred_keypoints_2d'].detach()[:, None, :, :]
        gt_keypoints_2d = batch['keypoints_2d'][:, None, :, :].repeat(1, num_samples, 1, 1)
        # The last channel of the ground truth is used as a per-keypoint weight.
        conf = gt_keypoints_2d[:, :, :, -1].clone()
        kp_err = torch.nn.functional.mse_loss(
            pred_keypoints_2d,
            gt_keypoints_2d[:, :, :, :-1],
            reduction='none'
        ).sum(dim=3)
        return (conf * kp_err).mean(dim=2).detach().cpu().numpy()

    def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None):
        """
        Evaluate current batch.

        The tensors inside `output`, `batch` and `opt_output` are not modified.

        Args:
            output (Dict): Regression output. Reads 'pred_keypoints_3d', and
                'pred_keypoints_2d' when a '*_kpl2' metric is tracked.
            batch (Dict): Dictionary containing images and their corresponding annotations.
                Reads 'keypoints_3d' (its last channel is dropped; presumably a
                confidence value -- confirm against the dataset), and
                'keypoints_2d' when a '*_kpl2' metric is tracked.
            opt_output (Dict): Optimization output; reads 'model_joints'.
        Returns:
            Dict: 'mode_mpjpe' and 'mode_re' for this batch when both metrics
                are tracked, otherwise an empty dict.
        Raises:
            ValueError: If 'opt_mpjpe' or 'opt_re' is tracked but `opt_output` is None.
        """
        # Fail early, before any state is touched, with a clear message.
        if opt_output is None:
            missing = [m for m in ('opt_mpjpe', 'opt_re') if hasattr(self, m)]
            if missing:
                raise ValueError(f'opt_output is required to record {missing}')

        if self.pck_evaluator is not None:
            self.pck_evaluator(output, batch, opt_output)

        # Insert a singleton "sample" dimension so that num_samples == 1.
        pred_keypoints_3d = output['pred_keypoints_3d'].detach()[:, None, :, :]
        batch_size = pred_keypoints_3d.shape[0]
        num_samples = pred_keypoints_3d.shape[1]
        gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1)

        # Align predictions and ground truth such that the pelvis location is at the origin.
        # Out-of-place on purpose: `detach()` and indexing return views, so `-=`
        # would overwrite the caller's tensors.
        pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, :, [self.pelvis_ind]]
        gt_keypoints_3d = gt_keypoints_3d - gt_keypoints_3d[:, :, [self.pelvis_ind]]

        # Compute joint errors
        flat_pred = pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3)
        flat_gt = gt_keypoints_3d.reshape(batch_size * num_samples, -1, 3)
        mpjpe, r_err = eval_pose(flat_pred[:, self.keypoint_list], flat_gt[:, self.keypoint_list])
        mpjpe = mpjpe.reshape(batch_size, num_samples)
        r_err = r_err.reshape(batch_size, num_samples)

        # The 0-th sample always corresponds to the mode
        mode_mpjpe = mpjpe[:, 0]
        mode_re = r_err[:, 0]
        batch_values = {
            'mode_mpjpe': mode_mpjpe,
            'mode_re': mode_re,
            'min_mpjpe': mpjpe.min(axis=-1),
            'min_re': r_err.min(axis=-1),
        }

        # Compute 2d keypoint errors only when one of them is actually recorded.
        if hasattr(self, 'mode_kpl2') or hasattr(self, 'min_kpl2'):
            kp_l2_loss = self._keypoint_2d_loss(output, batch, num_samples)
            batch_values['mode_kpl2'] = kp_l2_loss[:, 0]
            batch_values['min_kpl2'] = kp_l2_loss.min(axis=-1)

        # Compute joint errors after optimization, if available.
        if opt_output is not None:
            opt_keypoints_3d = opt_output['model_joints'].detach()
            opt_keypoints_3d = opt_keypoints_3d - opt_keypoints_3d[:, [self.pelvis_ind]]
            opt_mpjpe, opt_re = eval_pose(opt_keypoints_3d[:, self.keypoint_list],
                                          gt_keypoints_3d[:, 0, self.keypoint_list])
            batch_values['opt_mpjpe'] = opt_mpjpe
            batch_values['opt_re'] = opt_re

        # Write this batch into the next free slots of every tracked metric.
        window = slice(self.counter, self.counter + batch_size)
        for name, values in batch_values.items():
            if hasattr(self, name):
                getattr(self, name)[window] = values

        self.counter += batch_size

        if hasattr(self, 'mode_mpjpe') and hasattr(self, 'mode_re'):
            return {
                'mode_mpjpe': mode_mpjpe,
                'mode_re': mode_re,
            }
        else:
            return {}
class EvaluatorPCK:
    """
    Accumulates 2D keypoint predictions and ground truth over batches and
    reports PCK accuracy at a list of thresholds.
    """

    def __init__(self, thresholds: Optional[List] = None):
        """
        Class used for evaluating 2D keypoint predictions with PCK.

        Args:
            thresholds [List]: List of PCK thresholds to evaluate. Defaults to
                [0.05, 0.1, 0.2, 0.3, 0.4, 0.5].
        """
        # A fresh list per instance: a list default in the signature would be
        # shared (and mutable) across every EvaluatorPCK.
        if thresholds is None:
            thresholds = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
        self.thresholds = list(thresholds)
        # One NumPy array per recorded batch; concatenated in `compute_pcks`.
        self.pred_kp_2d = []
        self.gt_kp_2d = []
        self.gt_conf_2d = []
        self.counter = 0

    def log(self):
        """
        Print current evaluation metrics
        """
        if self.counter == 0:
            print('Evaluation has not started')
            return
        print(f'{self.counter} samples')
        metrics_dict = self.get_metrics_dict()
        for metric in metrics_dict:
            print(f'{metric}: {metrics_dict[metric]}')
        print('***')

    def get_metrics_dict(self) -> Dict:
        """
        Returns:
            Dict: Dictionary of evaluation metrics, keyed 'kp{i}_pck_{thr}' per
                keypoint and 'kpAvg_pck_{thr}' for the average. Empty if no
                batch has been recorded yet.
        """
        # Nothing accumulated: np.concatenate would raise on the empty lists.
        if not self.pred_kp_2d:
            return {}
        pcks = self.compute_pcks()
        metrics = {}
        for thr, (acc, avg_acc, cnt) in zip(self.thresholds, pcks):
            # Negative per-keypoint accuracies are skipped; presumably they mark
            # keypoints without valid annotations -- confirm against mmpose.
            metrics.update({f'kp{i}_pck_{thr}': float(a) for i, a in enumerate(acc) if a >= 0})
            metrics.update({f'kpAvg_pck_{thr}': float(avg_acc)})
        return metrics

    def compute_pcks(self):
        """
        Run mmpose's `keypoint_pck_accuracy` on everything recorded so far.

        Returns:
            List: One result per entry of `self.thresholds`, in the same order.
        """
        pred_kp_2d = np.concatenate(self.pred_kp_2d, axis=0)
        gt_kp_2d = np.concatenate(self.gt_kp_2d, axis=0)
        gt_conf_2d = np.concatenate(self.gt_conf_2d, axis=0)
        assert pred_kp_2d.shape == gt_kp_2d.shape
        assert pred_kp_2d[..., 0].shape == gt_conf_2d.shape
        assert pred_kp_2d.shape[1] == 1  # num_samples

        # Imported lazily so mmpose is only required when PCK is computed.
        from mmpose.core.evaluation import keypoint_pck_accuracy
        pcks = [
            keypoint_pck_accuracy(
                pred_kp_2d[:, 0, :, :],
                gt_kp_2d[:, 0, :, :],
                gt_conf_2d[:, 0, :] > 0.5,
                thr=thr,
                normalize=np.ones((len(pred_kp_2d), 2))  # Already in [-0.5,0.5] range. No need to normalize
            )
            for thr in self.thresholds
        ]
        return pcks

    def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None):
        """
        Evaluate current batch.

        Args:
            output (Dict): Regression output; reads 'pred_keypoints_2d'.
            batch (Dict): Dictionary containing images and their corresponding annotations;
                reads 'keypoints_2d', whose last channel is stored as the confidence.
            opt_output (Dict): Optimization output. Unused here.
        """
        pred_keypoints_2d = output['pred_keypoints_2d'].detach()
        num_samples = 1
        batch_size = pred_keypoints_2d.shape[0]
        # Insert a singleton "sample" dimension on both tensors.
        pred_keypoints_2d = pred_keypoints_2d[:, None, :, :]
        gt_keypoints_2d = batch['keypoints_2d'][:, None, :, :].repeat(1, num_samples, 1, 1)
        self.pred_kp_2d.append(pred_keypoints_2d[:, :, :, :2].detach().cpu().numpy())
        self.gt_conf_2d.append(gt_keypoints_2d[:, :, :, -1].detach().cpu().numpy())
        self.gt_kp_2d.append(gt_keypoints_2d[:, :, :, :2].detach().cpu().numpy())
        self.counter += batch_size
|