watchtowerss's picture
track-anything --version 1
4d1ebf3
raw
history blame
412 Bytes
import torch
import torch.nn.functional as F
# Soft aggregation, as introduced in STM (Space-Time Memory networks)
def aggregate(prob, dim, return_logits=False):
    """Softly aggregate probabilities along ``dim`` into a normalized distribution.

    An extra slice equal to ``prod(1 - prob)`` over ``dim`` is prepended along
    ``dim`` (presumably the "none of the entries" / background channel --
    confirm against callers). Every entry is then clamped to
    ``[1e-7, 1 - 1e-7]``, converted to a logit ``log(p / (1 - p))`` and
    renormalized with a softmax over ``dim``.

    Args:
        prob: Tensor of probabilities. The result has one more entry than
            ``prob`` along ``dim``.
        dim: Dimension along which to aggregate.
        return_logits: When true, also return the pre-softmax logits.

    Returns:
        The aggregated probabilities, or the tuple ``(logits, prob)`` when
        ``return_logits`` is true. Floating-point inputs get results in their
        own dtype.
    """
    # 1 - 1e-7 rounds to exactly 1.0 in float16/bfloat16, which would make
    # (1 - new_prob) zero and yield inf logits / NaN probabilities. Do the
    # arithmetic in at least float32 and cast back at the end; float32 and
    # float64 inputs follow exactly the same computation as before.
    compute_dtype = torch.promote_types(prob.dtype, torch.float32)
    result_dtype = prob.dtype if prob.is_floating_point() else compute_dtype
    prob = prob.to(compute_dtype)

    # Probability that none of the entries along `dim` is active.
    complement = torch.prod(1 - prob, dim=dim, keepdim=True)

    # Clamp away from 0 and 1 so that the logit below stays finite.
    new_prob = torch.cat([complement, prob], dim).clamp(1e-7, 1 - 1e-7)
    logits = torch.log(new_prob / (1 - new_prob))
    prob = F.softmax(logits, dim=dim).to(result_dtype)

    if return_logits:
        return logits.to(result_dtype), prob
    return prob