import math

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class GraphConvolution(nn.Module):
"""
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
"""
def __init__(self, in_features, out_features, bias=True, skip=True):
super(GraphConvolution, self).__init__()
self.skip = skip
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(in_features, out_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
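        # Initialize uniformly in [-1/sqrt(out_features), 1/sqrt(out_features)].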
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        # support = X W: project node features, [B, N, in] -> [B, N, out].
        support = torch.matmul(input, self.weight)
        # output = A (X W): aggregate the projected features over the graph.
        output = torch.bmm(adj, support)
        if self.bias is not None:
            output = output + self.bias  # broadcasts over batch and node dims
        if self.skip:
            output = output + support  # residual connection around the aggregation
        return output

    def __repr__(self):
        return '{} ({} -> {})'.format(
            self.__class__.__name__, self.in_features, self.out_features)


class GCN_sim(nn.Module):
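    """
    Stack of GraphConvolution layers over a soft, data-dependent graph: the
    adjacency matrix is built from learned key/query similarity between nodes
    (see construct_graph).
    """
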
def __init__(self, dim_in, dim_hidden, dim_out, dropout, num_layers):
super(GCN_sim, self).__init__()
assert num_layers >= 1
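        # Key/query projections used to score pairwise node similarity.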
self.fc_k = nn.Linear(dim_in, dim_hidden)
self.fc_q = nn.Linear(dim_in, dim_hidden)
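        # With a single layer, map directly from dim_in to dim_out.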
dim_hidden = dim_out if num_layers == 1 else dim_hidden
self.gcs = nn.ModuleList([
GraphConvolution(dim_in, dim_hidden)
])
for i in range(num_layers - 1):
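            # The final layer maps to dim_out; intermediate layers keep dim_hidden.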
dim_tmp = dim_out if i == num_layers-2 else dim_hidden
self.gcs.append(GraphConvolution(dim_hidden, dim_tmp))
        self.dropout = dropout

    def construct_graph(self, x, length):
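        """
        Build a soft adjacency matrix [B, N, N]: a row-wise softmax over
        key/query similarity scores, with padded nodes masked out.
        """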
        # TODO make fc more efficient via "pack_padded_sequence"
        emb_k = self.fc_k(x)
        emb_q = self.fc_q(x)
        s = torch.bmm(emb_k, emb_q.transpose(1, 2))  # [B, N, N]
        # Mask entries involving padded nodes, using each sample's length.
        s_mask = torch.ones_like(s, dtype=torch.bool)  # True = masked
        for i, l in enumerate(length):
            s_mask[i, :l, :l] = False
        s = s.masked_fill(s_mask, -float("inf"))
        a_weight = F.softmax(s, dim=2)  # [B, N, N]
        # Fully masked rows produce NaN after the softmax; zero them out.
        a_weight = a_weight.masked_fill(torch.isnan(a_weight), 0)
        return a_weight

    def forward(self, x, length):
adj_sim = self.construct_graph(x, length)
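        # Run every GCN layer over the shared similarity graph.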
for gc in self.gcs:
x = F.relu(gc(x, adj_sim))
x = F.dropout(x, self.dropout, training=self.training)
return x


class GCN(nn.Module):
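    """
    Wrapper that runs the graph branches listed in `mode` (only "GCN_sim"
    here) and sums their outputs, with an optional residual connection.
    """
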
def __init__(self, dim_in, dim_hidden, dim_out, dropout, mode, skip, num_layers, ST_n_next=None):
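        # ST_n_next is unused by the "GCN_sim" branch.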
super(GCN, self).__init__()
assert len(mode) != 0
self.mode = mode
self.skip = skip
if "GCN_sim" in mode:
self.GCN_sim = GCN_sim(
                dim_in, dim_hidden, dim_out, dropout, num_layers)

    def forward(self, x, length):
out = []
if "GCN_sim" in self.mode:
out.append(self.GCN_sim(x, length))
        # Sum the outputs of all enabled branches.
        out = sum(out)
        if self.skip:
            out = out + x  # residual connection; requires dim_out == dim_in
        return out


if __name__ == '__main__':
    # Smoke test: a batch of T*N nodes with 512-d features per node.
    model = GCN(512, 128, 512, 0.5, mode=["GCN_sim"],
                skip=True, num_layers=3, ST_n_next=3)
    bs, T, N = 10, 5, 10
    n_node = T * N
    input = torch.rand(bs, n_node, 512)
    length = torch.ones(bs, dtype=torch.int32)  # one valid node per sample
    bboxes = torch.rand((bs, 5, 10, 4))  # unused by the "GCN_sim" branch
    output = model(input, length)  # [bs, n_node, 512]