import random
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from src.constants import UNET_LAYERS
from src.models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
from src.utils.types import PESigmas


class NeTIMapper(nn.Module):
    """ Main logic of our NeTI mapper: maps a (denoising timestep, U-Net layer) pair to a token embedding. """

    def __init__(self, output_dim: int = 768,
                 unet_layers: List[str] = UNET_LAYERS,
                 use_nested_dropout: bool = True,
                 nested_dropout_prob: float = 0.5,
                 norm_scale: Optional[torch.Tensor] = None,
                 use_positional_encoding: bool = True,
                 num_pe_time_anchors: int = 10,
                 pe_sigmas: PESigmas = PESigmas(sigma_t=0.03, sigma_l=2.0),
                 output_bypass: bool = True):
        super().__init__()
        self.use_nested_dropout = use_nested_dropout
        self.nested_dropout_prob = nested_dropout_prob
        self.norm_scale = norm_scale
        self.output_bypass = output_bypass
        if self.output_bypass:
            # The bypass reserves a second output vector alongside the base embedding, so double the projection size
            output_dim *= 2

        self.use_positional_encoding = use_positional_encoding
        if self.use_positional_encoding:
            self.encoder = NeTIPositionalEncoding(sigma_t=pe_sigmas.sigma_t, sigma_l=pe_sigmas.sigma_l).cuda()
            self.input_dim = num_pe_time_anchors * len(unet_layers)
        else:
            self.encoder = BasicEncoder().cuda()
            self.input_dim = 2

        self.set_net(num_unet_layers=len(unet_layers),
                     num_time_anchors=num_pe_time_anchors,
                     output_dim=output_dim)

    def set_net(self, num_unet_layers: int, num_time_anchors: int, output_dim: int = 768):
        self.input_layer = self.set_input_layer(num_unet_layers, num_time_anchors)
        # Small MLP: two hidden layers of width 128 (LayerNorm + LeakyReLU), then a linear projection to output_dim
        self.net = nn.Sequential(self.input_layer,
                                 nn.Linear(self.input_dim, 128), nn.LayerNorm(128), nn.LeakyReLU(),
                                 nn.Linear(128, 128), nn.LayerNorm(128), nn.LeakyReLU())
        self.output_layer = nn.Sequential(nn.Linear(128, output_dim))

    def set_input_layer(self, num_unet_layers: int, num_time_anchors: int) -> nn.Module:
        if self.use_positional_encoding:
            # Initialize the projection weights from the encoder's time/layer anchor encodings
            input_layer = nn.Linear(self.encoder.num_w * 2, self.input_dim)
            input_layer.weight.data = self.encoder.init_layer(num_time_anchors, num_unet_layers)
        else:
            input_layer = nn.Identity()
        return input_layer

    def forward(self, timestep: torch.Tensor, unet_layer: torch.Tensor, truncation_idx: Optional[int] = None) -> torch.Tensor:
        embedding = self.extract_hidden_representation(timestep, unet_layer)
        if self.use_nested_dropout:
            embedding = self.apply_nested_dropout(embedding, truncation_idx=truncation_idx)
        embedding = self.get_output(embedding)
        return embedding

    def get_encoded_input(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        return self.encoder.encode(timestep, unet_layer)

    def extract_hidden_representation(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        encoded_input = self.get_encoded_input(timestep, unet_layer)
        embedding = self.net(encoded_input)
        return embedding

    def apply_nested_dropout(self, embedding: torch.Tensor, truncation_idx: Optional[int] = None) -> torch.Tensor:
        if self.training:
            if random.random() < self.nested_dropout_prob:
                # Zero out a random suffix of each embedding so the leading dimensions carry the most information
                dropout_idxs = torch.randint(low=0, high=embedding.shape[1], size=(embedding.shape[0],))
                for idx in torch.arange(embedding.shape[0]):
                    embedding[idx][dropout_idxs[idx]:] = 0
        if not self.training and truncation_idx is not None:
            # At inference, optionally keep only the first truncation_idx dimensions
            for idx in torch.arange(embedding.shape[0]):
                embedding[idx][truncation_idx:] = 0
        return embedding

    def get_output(self, embedding: torch.Tensor) -> torch.Tensor:
        embedding = self.output_layer(embedding)
        if self.norm_scale is not None:
            embedding = F.normalize(embedding, dim=-1) * self.norm_scale
        return embedding
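# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumptions: a CUDA device is available (the encoder is moved to GPU in
# __init__), and NeTIPositionalEncoding.encode accepts batched 1-D tensors of
# diffusion timesteps and U-Net layer indices, as implied by forward() above.
if __name__ == "__main__":
    mapper = NeTIMapper().cuda()  # default output_dim=768, doubled to 1536 by output_bypass
    batch_size = 4
    timesteps = torch.randint(low=0, high=1000, size=(batch_size,)).cuda()
    unet_layer_idxs = torch.randint(low=0, high=len(UNET_LAYERS), size=(batch_size,)).cuda()
    out = mapper(timesteps, unet_layer_idxs)
    print(out.shape)  # expected: (4, 1536) with the default arguments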