import torch
import torch.nn.functional as F

def kl_loss(code):
    # L2 penalty on the latent code (mean of squared activations).
    return torch.mean(torch.pow(code, 2))

def pairwise_cosine_similarity(seqs_i, seqs_j):
    # seqs_i, seqs_j: [batch, statics, channel]
    # Returns a [batch, statics, statics] matrix of cosine similarities between
    # every pair of static feature vectors, computed over the channel dimension.
    n_statics = seqs_i.size(1)
    seqs_i_exp = seqs_i.unsqueeze(2).repeat(1, 1, n_statics, 1)
    seqs_j_exp = seqs_j.unsqueeze(1).repeat(1, n_statics, 1, 1)
    return F.cosine_similarity(seqs_i_exp, seqs_j_exp, dim=3)
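
# Shape sketch (illustrative values, not from the original file):
#   x = torch.randn(2, 6, 16)               # [batch=2, statics=6, channel=16]
#   pairwise_cosine_similarity(x, x).shape  # -> torch.Size([2, 6, 6])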

def temporal_pairwise_cosine_similarity(seqs_i, seqs_j):
    # seqs_i, seqs_j: [batch, channel, time]
    # Returns a [batch, time, time] matrix of cosine similarities between every
    # pair of time steps, computed over the channel dimension.
    seq_len = seqs_i.size(2)
    seqs_i_exp = seqs_i.unsqueeze(3).repeat(1, 1, 1, seq_len)
    seqs_j_exp = seqs_j.unsqueeze(2).repeat(1, 1, seq_len, 1)
    return F.cosine_similarity(seqs_i_exp, seqs_j_exp, dim=1)
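
# Shape sketch (illustrative values, not from the original file): entry
# (t_i, t_j) compares time step t_i of seqs_i with time step t_j of seqs_j.
#   x = torch.randn(2, 8, 5)                          # [batch=2, channel=8, time=5]
#   temporal_pairwise_cosine_similarity(x, x).shape   # -> torch.Size([2, 5, 5])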

def consecutive_cosine_similarity(seqs):
    # seqs: [batch, channel, time] -> [batch, time - 1]: cosine similarity
    # between each time step and the previous one, over the channel dimension.
    # The wrapped-around first step from roll() is dropped along the time axis.
    seqs_roll = seqs.roll(shifts=1, dims=2)[:, :, 1:]
    seqs = seqs[:, :, 1:]
    return F.cosine_similarity(seqs, seqs_roll, dim=1)
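
# Shape sketch (illustrative values, not from the original file):
#   x = torch.randn(2, 8, 5)                 # [batch=2, channel=8, time=5]
#   consecutive_cosine_similarity(x).shape   # -> torch.Size([2, 4])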

def triplet_margin_loss(seqs_a, seqs_b, neg_range=(0.0, 0.5), margin=0.2):
    # seqs_a, seqs_b: [batch, channel, time]
    # Hinge-style contrastive loss on temporal similarity: each sequence's time
    # steps should be more similar to each other (sim_aa, sim_bb) than to the
    # paired sequence's time steps (sim_ab, sim_ba). neg_range = (start, end)
    # keeps only the negatives ranked between those fractions of the hardest.
    neg_start, neg_end = neg_range
    batch_size, _, seq_len = seqs_a.size()
    n_neg_all = seq_len ** 2
    n_neg = int(round(neg_end * n_neg_all))
    n_neg_discard = int(round(neg_start * n_neg_all))
    sim_aa = temporal_pairwise_cosine_similarity(seqs_a, seqs_a)
    sim_bb = temporal_pairwise_cosine_similarity(seqs_b, seqs_b)
    sim_ab = temporal_pairwise_cosine_similarity(seqs_a, seqs_b)
    sim_ba = sim_ab.transpose(1, 2)
    # Margin violations for every (time_i, time_j) pair, flattened per sequence.
    diff_ab = (sim_ab - sim_aa).reshape(batch_size, -1)
    diff_ba = (sim_ba - sim_bb).reshape(batch_size, -1)
    diff = torch.cat([diff_ab, diff_ba], dim=0)
    # Rank negatives from hardest to easiest and keep the slice selected by
    # neg_range before applying the hinge.
    diff, _ = diff.topk(n_neg, dim=-1, sorted=True)
    diff = diff[:, n_neg_discard:]
    loss = diff + margin
    loss = loss.clamp(min=0.)
    loss = loss.mean()
    return loss
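

# Minimal usage sketch, not part of the original module: the tensor sizes and
# the 0.1 weighting of kl_loss are illustrative assumptions, not values taken
# from the source.
if __name__ == "__main__":
    batch_size, n_channels, seq_len = 4, 16, 32
    seqs_a = torch.randn(batch_size, n_channels, seq_len)  # e.g. one encoder view
    seqs_b = torch.randn(batch_size, n_channels, seq_len)  # e.g. the paired view
    code = torch.randn(batch_size, 8)                      # e.g. a latent code

    total = triplet_margin_loss(seqs_a, seqs_b) + 0.1 * kl_loss(code)
    print(total.item())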