File size: 3,967 Bytes
3dba732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor
from typing import Tuple

from torchvision.models import resnet18, resnet50
from torchvision.models import ResNet18_Weights, ResNet50_Weights

class DistMult(nn.Module):
    """DistMult-style scorer whose heads can be ids, coordinates, times or images.

    Every head modality is embedded into a shared 512-d space by
    ``batch_embedding_concat_h``; relations come from a 4-row embedding table.
    ``forward_ce`` scores each (head, relation) pair against all candidate
    tails of the requested type as ``(emb_h * emb_r) @ tails.T``.

    Args:
        num_ent_uid: Number of rows in the entity-id embedding table.
        target_list: Index tensor selecting the entity rows that act as
            candidate tails for ``('image', 'id')`` triples (a trailing
            singleton dim is squeezed).
        device: Device that ``all_locs`` / ``all_timestamps`` are moved to.
        all_locs: Optional candidate-location tensor for
            ``('image', 'location')`` triples; it is fed to
            ``location_embedding``, whose first layer takes 2 input features.
        num_habitat: Accepted but unused by this class.
        all_timestamps: Optional candidate-time tensor for
            ``('image', 'time')`` triples; it is fed to ``time_embedding``,
            whose first layer takes 1 input feature.
    """

    def __init__(self, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
        super().__init__()
        self.num_ent_uid = num_ent_uid

        # Size of the relation vocabulary (fixed).
        self.num_relations = 4

        self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, 512, sparse=False)
        self.rel_embedding = torch.nn.Embedding(self.num_relations, 512, sparse=False)

        # Encoders that map non-id inputs into the same 512-d space.
        self.location_embedding = MLP(2, 512, 3)

        self.time_embedding = MLP(1, 512, 3)

        # ImageNet-pretrained ResNet-50 whose classifier head is replaced by a
        # 2048 -> 512 projection.
        self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.image_embedding.fc = nn.Linear(2048, 512)

        self.target_list = target_list

        # Candidate tails for location/time triples. When not supplied the
        # attribute is left unset, so forward_ce cannot score those triple types.
        if all_locs is not None:
            self.all_locs = all_locs.to(device)
        if all_timestamps is not None:
            self.all_timestamps = all_timestamps.to(device)

        self.device = device

        self.init()

    def init(self):
        """Xavier-uniform initialise the entity/relation tables and the image projection weight.

        The projection's bias keeps ``nn.Linear``'s default initialisation.
        """
        # nn.init functions already operate under no_grad, so `.data` is not needed.
        nn.init.xavier_uniform_(self.ent_embedding.weight)
        nn.init.xavier_uniform_(self.rel_embedding.weight)
        nn.init.xavier_uniform_(self.image_embedding.fc.weight)

    def forward_ce(self, h, r, triple_type=None):
        """Score every (head, relation) pair against all candidate tails.

        Args:
            h: Head batch, in any form accepted by ``batch_embedding_concat_h``.
            r: Relation index tensor; a trailing singleton dim is squeezed.
            triple_type: ``(head_type, tail_type)`` pair selecting the
                candidate tail set: ``('image', 'id')``, ``('id', 'id')``,
                ``('image', 'location')`` or ``('image', 'time')``.

        Returns:
            Score matrix of shape ``[batch, n_candidates]``.

        Raises:
            NotImplementedError: if ``triple_type`` is not one of the pairs above.
        """
        emb_h = self.batch_embedding_concat_h(h)  # [batch, hid]
        emb_r = self.rel_embedding(r.squeeze(-1))  # [batch, hid]
        emb_hr = emb_h * emb_r  # DistMult element-wise product, [batch, hid]

        if triple_type == ('image', 'id'):
            # Only the entity rows listed in target_list are candidates.
            tails = self.ent_embedding.weight[self.target_list.squeeze(-1)]
        elif triple_type == ('id', 'id'):
            tails = self.ent_embedding.weight
        elif triple_type == ('image', 'location'):
            # Recomputed on every call rather than cached.
            tails = self.location_embedding(self.all_locs)
        elif triple_type == ('image', 'time'):
            tails = self.time_embedding(self.all_timestamps)
        else:
            raise NotImplementedError(f'unsupported triple_type: {triple_type!r}')

        return torch.mm(emb_hr, tails.T)  # [batch, n_candidates]

    def batch_embedding_concat_h(self, e1):
        """Embed a batch of heads, choosing the encoder from the tensor's shape.

        * 1-D, or second dim 1: entity ids -> ``ent_embedding``.
        * second dim 15: time features -> ``time_embedding``.
        * second dim 2: GPS coordinates -> ``location_embedding``.
        * second dim 3: images (presumably ``[batch, 3, H, W]``) -> ``image_embedding``.

        Returns:
            ``[batch, 512]`` embeddings.

        Raises:
            ValueError: if the shape matches none of the cases above.
        """
        if e1.dim() == 1:  # uid, already flat
            # Do not squeeze: a batch of one would collapse to a 0-d index and
            # lose its batch dimension.
            return self.ent_embedding(e1)

        width = e1.size(1)
        if width == 1:  # uid
            return self.ent_embedding(e1.squeeze(-1))
        if width == 15:  # time
            # NOTE(review): time_embedding is MLP(1, ...), whose first Linear
            # expects 1 input feature, so a 15-wide tensor will fail inside it,
            # while 1-wide inputs are routed to the uid branch above. Confirm
            # the intended time representation.
            return self.time_embedding(e1)
        if width == 2:  # GPS
            return self.location_embedding(e1)
        if width == 3:  # Image
            return self.image_embedding(e1)

        raise ValueError(
            f'cannot infer head modality from tensor of shape {tuple(e1.shape)}'
        )


class MLP(nn.Module):
    def __init__(self,
                 input_dim,
                 output_dim,
                 num_layers=3,
                 p_dropout=0.0,
                 bias=True):

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.p_dropout = p_dropout
        step_size = (input_dim - output_dim) // num_layers
        hidden_dims = [output_dim + (i * step_size)
                       for i in reversed(range(num_layers))]

        mlp = list()
        layer_indim = input_dim
        for hidden_dim in hidden_dims:
            mlp.extend([nn.Linear(layer_indim, hidden_dim, bias),
                        nn.Dropout(p=self.p_dropout, inplace=True),
                        nn.PReLU()])

            layer_indim = hidden_dim

        self.mlp = nn.Sequential(*mlp)

        # initialize weights
        self.init()

    def forward(self, x):
        return self.mlp(x)

    def init(self):
        for param in self.parameters():
            nn.init.uniform_(param)