# NOTE(review): removed scraper artifacts ("Spaces:" / "Runtime error" lines)
# that preceded the imports and made this file unparseable.
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import math
from torch.autograd import Variable
from torchvision.ops import box_iou
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    Computes ``out = adj @ (input @ W) [+ bias] [+ input @ W]`` where the
    last term is an optional skip connection on the *transformed* features.
    """

    def __init__(self, in_features, out_features, bias=True, skip=True):
        super(GraphConvolution, self).__init__()
        self.skip = skip
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            # Registering None keeps `self.bias` accessible and state_dict-clean.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init scaled by fan-out (weight.size(1)), matching the
        # reference GCN implementation; nn.init avoids direct .data mutation.
        stdv = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, input, adj):
        """
        input: (B, N, in_features) node features
        adj:   (B, N, N) adjacency / attention weights
        returns (B, N, out_features)
        """
        # torch.matmul broadcasts the 2-D weight over the batch dimension,
        # so the original unsqueeze/expand + bmm dance is unnecessary.
        support = torch.matmul(input, self.weight)
        output = torch.bmm(adj, support)
        if self.bias is not None:
            # 1-D bias broadcasts over (B, N) without a manual expand.
            output = output + self.bias
        if self.skip:
            # Skip connection adds the transformed features, not the raw input.
            output = output + support
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
class GCN_sim(nn.Module):
    """
    Stack of GraphConvolution layers over a learned similarity graph.

    The adjacency is dot-product attention between projected node features,
    masked by each sample's valid length and row-normalized with softmax.
    """

    def __init__(self, dim_in, dim_hidden, dim_out, dropout, num_layers):
        super(GCN_sim, self).__init__()
        assert num_layers >= 1
        # Key/query projections used to score pairwise node similarity.
        self.fc_k = nn.Linear(dim_in, dim_hidden)
        self.fc_q = nn.Linear(dim_in, dim_hidden)
        # With a single layer the first (only) GC maps straight to dim_out.
        dim_hidden = dim_out if num_layers == 1 else dim_hidden
        self.gcs = nn.ModuleList([
            GraphConvolution(dim_in, dim_hidden)
        ])
        for i in range(num_layers - 1):
            # Last layer projects to dim_out; intermediates keep dim_hidden.
            dim_tmp = dim_out if i == num_layers - 2 else dim_hidden
            self.gcs.append(GraphConvolution(dim_hidden, dim_tmp))
        self.dropout = dropout

    def construct_graph(self, x, length):
        """
        Build a row-normalized similarity adjacency.

        x:      (B, T, dim_in) node features
        length: per-sample count of valid nodes (first length[i] rows/cols)
        returns (B, T, T) attention weights; fully-masked rows are all zero.
        """
        # TODO make fc more efficient via "pack_padded_sequence"
        emb_k = self.fc_k(x)
        emb_q = self.fc_q(x)
        s = torch.bmm(emb_k, emb_q.transpose(1, 2))  # (B, T, T) scores
        # True where a position is padding; only the leading l x l square of
        # each sample is valid.
        s_mask = torch.ones_like(s, dtype=torch.bool)
        for i, l in enumerate(length):
            s_mask[i, :l, :l] = False
        # Out-of-place masked_fill keeps autograd intact — the original
        # mutated .data in place, which silently bypasses gradient tracking
        # (Variable is also long deprecated; tensors carry autograd now).
        s = s.masked_fill(s_mask, -float("inf"))
        a_weight = F.softmax(s, dim=2)  # [B, t1, t2]
        # Rows that are entirely -inf softmax to NaN; zero them out
        # (torch.isnan replaces the old `x != x` NaN trick).
        a_weight = a_weight.masked_fill(torch.isnan(a_weight), 0.)
        return a_weight

    def forward(self, x, length):
        adj_sim = self.construct_graph(x, length)
        for gc in self.gcs:
            x = F.relu(gc(x, adj_sim))
            x = F.dropout(x, self.dropout, training=self.training)
        return x
class GCN(nn.Module):
    """Dispatcher that runs the configured graph-convolution branches,
    sums their outputs, and optionally adds a residual connection."""

    def __init__(self, dim_in, dim_hidden, dim_out, dropout, mode, skip, num_layers, ST_n_next=None):
        super(GCN, self).__init__()
        assert len(mode) != 0
        self.mode = mode
        self.skip = skip
        # Only the similarity-graph branch is wired up here; `mode` selects it.
        if "GCN_sim" in mode:
            self.GCN_sim = GCN_sim(
                dim_in, dim_hidden, dim_out, dropout, num_layers)

    def forward(self, x, length):
        branch_outputs = []
        if "GCN_sim" in self.mode:
            branch_outputs.append(self.GCN_sim(x, length))
        combined = sum(branch_outputs)
        if self.skip:
            # Residual connection with the raw input features.
            combined = combined + x
        return combined
if __name__ == '__main__':
    # Smoke test: a 3-layer similarity-GCN over T*N graph nodes per sample.
    model = GCN(512, 128, 512, 0.5, mode=[
        "GCN_sim"], skip=True, num_layers=3, ST_n_next=3)
    bs, T, N = 10, 5, 10
    n_node = T * N
    input = torch.rand(bs, n_node, 512)
    # Valid length 1 per sample (the original built float ones, then cast);
    # passing dtype directly removes the redundant two-step conversion.
    length = torch.ones(bs, dtype=torch.int32)
    # NOTE(review): dropped the unused `bboxes` tensor the original created.
    output = model(input, length)