Spaces:
Running
on
T4
Running
on
T4
import torch | |
from torch import nn | |
class PositionalEncodingsFixed(nn.Module):
    """Fixed (non-learned) 2D sinusoidal positional encodings.

    For a ``bs x h x w`` grid, the module produces a tensor of shape
    ``(bs, 2 * (emb_dim // 2), h, w)``: the first ``emb_dim // 2`` channels
    encode the row (y) position and the remaining ``emb_dim // 2`` channels
    encode the column (x) position. Within each half, channels alternate
    ``sin, cos`` and each sin/cos pair shares one frequency.

    Args:
        emb_dim: Total embedding size. ``emb_dim // 2`` must be even so that
            every sine channel has a matching cosine channel.
        temperature: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``emb_dim // 2`` is odd (such a configuration cannot
            be split into sin/cos pairs and would fail inside ``forward``).
    """

    def __init__(self, emb_dim, temperature=10000):
        super().__init__()
        # Fail fast: with an odd per-axis size the sin and cos slices in
        # `_1d_pos_enc` have different lengths and `torch.stack` raises an
        # opaque size-mismatch error on the first forward pass.
        if (emb_dim // 2) % 2 != 0:
            raise ValueError(
                "emb_dim // 2 must be even to form sin/cos pairs, "
                f"got emb_dim={emb_dim}"
            )
        self.emb_dim = emb_dim
        self.temperature = temperature

    def _1d_pos_enc(self, mask, dim):
        """Encode positions along one axis of ``mask``.

        Args:
            mask: Boolean tensor; ``False`` entries are counted as valid
                positions, ``True`` entries do not advance the position.
            dim: Axis of ``mask`` along which positions are accumulated.

        Returns:
            Float tensor of shape ``mask.shape + (emb_dim // 2,)`` with
            interleaved ``sin, cos`` channels.
        """
        channel_idx = torch.arange(
            self.emb_dim // 2, dtype=torch.float32, device=mask.device
        )
        # Floor-dividing the channel index by 2 gives each (sin, cos) pair
        # the same divisor.
        divisors = self.temperature ** (
            2 * channel_idx.div(2, rounding_mode='floor') / self.emb_dim
        )
        # Cumulative count of unmasked cells => positions start at 1.
        positions = (~mask).cumsum(dim).float().unsqueeze(-1)
        scaled = positions / divisors
        # Even channels -> sin, odd channels -> cos, re-interleaved by
        # stacking on a new last axis and flattening it away.
        interleaved = torch.stack(
            [scaled[..., 0::2].sin(), scaled[..., 1::2].cos()], dim=-1
        )
        return interleaved.flatten(-2)

    def forward(self, bs, h, w, device):
        """Build positional encodings for a full (unmasked) grid.

        Args:
            bs: Batch size.
            h: Grid height.
            w: Grid width.
            device: Device on which the encodings are created.

        Returns:
            Float tensor of shape ``(bs, 2 * (emb_dim // 2), h, w)`` with
            y-encodings in the first half of the channels and x-encodings in
            the second half.
        """
        # All-False mask: every cell is a valid position.
        mask = torch.zeros(bs, h, w, dtype=torch.bool, device=device)
        x = self._1d_pos_enc(mask, dim=2)
        y = self._1d_pos_enc(mask, dim=1)
        # (bs, h, w, C) -> (bs, C, h, w)
        return torch.cat([y, x], dim=3).permute(0, 3, 1, 2)