import random
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from constants import UNET_LAYERS
from models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
from utils.types import PESigmas


class NeTIMapper(nn.Module):
    """ Main logic of our NeTI mapper. """

    def __init__(self, output_dim: int = 768,
                 unet_layers: List[str] = UNET_LAYERS,
                 use_nested_dropout: bool = True,
                 nested_dropout_prob: float = 0.5,
                 norm_scale: Optional[torch.Tensor] = None,
                 use_positional_encoding: bool = True,
                 num_pe_time_anchors: int = 10,
                 pe_sigmas: PESigmas = PESigmas(sigma_t=0.03, sigma_l=2.0),
                 output_bypass: bool = True):
        super().__init__()
        self.use_nested_dropout = use_nested_dropout
        self.nested_dropout_prob = nested_dropout_prob
        self.norm_scale = norm_scale
        self.output_bypass = output_bypass
        if self.output_bypass:
            output_dim *= 2  # Output two vectors: the base embedding and the textual bypass

        self.use_positional_encoding = use_positional_encoding
        if self.use_positional_encoding:
            self.encoder = NeTIPositionalEncoding(sigma_t=pe_sigmas.sigma_t, sigma_l=pe_sigmas.sigma_l).cuda()
            self.input_dim = num_pe_time_anchors * len(unet_layers)
        else:
            self.encoder = BasicEncoder().cuda()
            self.input_dim = 2

        self.set_net(num_unet_layers=len(unet_layers),
                     num_time_anchors=num_pe_time_anchors,
                     output_dim=output_dim)

    def set_net(self, num_unet_layers: int, num_time_anchors: int, output_dim: int = 768):
        self.input_layer = self.set_input_layer(num_unet_layers, num_time_anchors)
        self.net = nn.Sequential(self.input_layer,
                                 nn.Linear(self.input_dim, 128), nn.LayerNorm(128), nn.LeakyReLU(),
                                 nn.Linear(128, 128), nn.LayerNorm(128), nn.LeakyReLU())
        self.output_layer = nn.Sequential(nn.Linear(128, output_dim))

    def set_input_layer(self, num_unet_layers: int, num_time_anchors: int) -> nn.Module:
        if self.use_positional_encoding:
            input_layer = nn.Linear(self.encoder.num_w * 2, self.input_dim)
            input_layer.weight.data = self.encoder.init_layer(num_time_anchors, num_unet_layers)
        else:
            input_layer = nn.Identity()
        return input_layer

    def forward(self, timestep: torch.Tensor, unet_layer: torch.Tensor,
                truncation_idx: Optional[int] = None) -> torch.Tensor:
        embedding = self.extract_hidden_representation(timestep, unet_layer)
        if self.use_nested_dropout:
            embedding = self.apply_nested_dropout(embedding, truncation_idx=truncation_idx)
        embedding = self.get_output(embedding)
        return embedding

    def get_encoded_input(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        return self.encoder.encode(timestep, unet_layer)

    def extract_hidden_representation(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        encoded_input = self.get_encoded_input(timestep, unet_layer)
        embedding = self.net(encoded_input)
        return embedding

    def apply_nested_dropout(self, embedding: torch.Tensor, truncation_idx: Optional[int] = None) -> torch.Tensor:
        # During training, with probability nested_dropout_prob, zero out each
        # sample's embedding from a randomly drawn index onward, so earlier
        # dimensions learn to carry the most information (nested dropout).
        if self.training:
            if random.random() < self.nested_dropout_prob:
                dropout_idxs = torch.randint(low=0, high=embedding.shape[1], size=(embedding.shape[0],))
                for idx in torch.arange(embedding.shape[0]):
                    embedding[idx][dropout_idxs[idx]:] = 0
        # At inference, an explicit truncation index can be passed to zero out
        # the tail of every embedding deterministically.
        if not self.training and truncation_idx is not None:
            for idx in torch.arange(embedding.shape[0]):
                embedding[idx][truncation_idx:] = 0
        return embedding

    def get_output(self, embedding: torch.Tensor) -> torch.Tensor:
        embedding = self.output_layer(embedding)
        if self.norm_scale is not None:
            embedding = F.normalize(embedding, dim=-1) * self.norm_scale
        return embedding