|
from typing import Optional, Tuple

import torch
from torch import nn
|
|
|
|
|
class AlignmentNetwork(torch.nn.Module):
    """Aligner Network for learning alignment between the input text and the model output with Gaussian Attention.

    ::

        query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment
        key   -> conv1d -> relu -> conv1d -----------------------^

    Args:
        in_query_channels (int): Number of channels in the query network. Defaults to 80.
        in_key_channels (int): Number of channels in the key network. Defaults to 512.
        attn_channels (int): Number of inner channels in the attention layers. Defaults to 80.
        temperature (float): Temperature for the softmax. The negated squared L2 distance
            between projected queries and keys is multiplied by this value. Defaults to 0.0005.
    """

    def __init__(
        self,
        in_query_channels: int = 80,
        in_key_channels: int = 512,
        attn_channels: int = 80,
        temperature: float = 0.0005,
    ):
        super().__init__()
        self.temperature = temperature
        # Both normalisations run over dim 3, which is the key axis of the
        # [B, 1, T_de, T_en] score tensor built in ``forward``.
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        # Key projection: in_key_channels -> 2 * in_key_channels -> attn_channels.
        self.key_layer = nn.Sequential(
            nn.Conv1d(
                in_key_channels,
                in_key_channels * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            torch.nn.ReLU(),
            nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
        )

        # Query projection: in_query_channels -> 2 * in_query_channels
        # -> in_query_channels -> attn_channels.
        self.query_layer = nn.Sequential(
            nn.Conv1d(
                in_query_channels,
                in_query_channels * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            torch.nn.ReLU(),
            nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
            torch.nn.ReLU(),
            nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
        )

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        attn_prior: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the aligner encoder.

        Shapes:
            - queries: :math:`[B, C, T_de]`
            - keys: :math:`[B, C_emb, T_en]`
            - mask: must broadcast, after ``unsqueeze(2)``, against
              :math:`[B, 1, T_de, T_en]` (e.g. :math:`[B, 1, T_en]`). Positions where
              the mask is zero/False are excluded from the alignment.
            - attn_prior: :math:`[B, T_de, T_en]`; a new axis is inserted at dim 1
              before it is combined with the scores.

        Output:
            attn (torch.Tensor): :math:`[B, 1, T_de, T_en]` soft attention mask,
                normalised over the last (key) axis.
            attn_logp (torch.Tensor): :math:`[B, 1, T_de, T_en]` log-domain scores that
                were fed to the softmax: the scaled negative squared distances, or, when
                ``attn_prior`` is given, their log-softmax plus the log prior. Masked
                positions hold ``-inf``.
        """
        key_out = self.key_layer(keys)  # [B, attn_channels, T_en]
        query_out = self.query_layer(queries)  # [B, attn_channels, T_de]

        # Pairwise squared differences, broadcast to [B, attn_channels, T_de, T_en].
        attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
        # Summing over channels yields the squared L2 distance for each pair.
        attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)

        if attn_prior is not None:
            # The epsilon keeps log() finite where the prior is exactly zero.
            attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8)

        if mask is not None:
            # Out-of-place fill keeps the operation visible to autograd, so no
            # gradient leaks into entries that were overwritten with -inf
            # (writing through ``.data`` would bypass gradient tracking).
            attn_logp = attn_logp.masked_fill(~mask.bool().unsqueeze(2), -float("inf"))

        attn = self.softmax(attn_logp)
        return attn, attn_logp
|
|