uform-gen2-dpo / vision_encoder.py
VoVoR's picture
Add: DPO model card init
0556e7f
raw
history blame
5.27 kB
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch import Tensor
from typing import Optional
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
dropout_prob: float = 0
):
super().__init__()
self.use_sdp = int(torch.__version__[0]) > 1
self.query = nn.Linear(dim, dim)
self.key = nn.Linear(dim, dim)
self.value = nn.Linear(dim, dim)
self.out = nn.Linear(dim, dim)
self.dropout_prob = dropout_prob
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
def forward(
self,
x: Tensor,
attn_mask: Optional[Tensor] = None,
context: Optional[Tensor] = None,
is_causal: bool = False,
) -> Tensor:
query = self.reshape(self.query(x))
key = self.reshape(self.key(x if context is None else context))
value = self.reshape(self.value(x if context is None else context))
if self.use_sdp:
x = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask,
dropout_p=self.dropout_prob if self.training else 0,
is_causal=is_causal,
)
else:
attn = query @ key.transpose(-2, -1) * self.scale
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
x = attn @ value
return self.out(x.transpose(2, 1).flatten(2))
def reshape(self, x: Tensor) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
return x.transpose(2, 1)
class MLP(nn.Module):
def __init__(
self,
dim: int,
dim_expand_factor: int = 4
):
super().__init__()
self.hidden_layer = nn.Linear(dim, dim * dim_expand_factor)
self.output_layer = nn.Linear(dim * dim_expand_factor, dim)
def forward(self, x: Tensor) -> Tensor:
x = F.gelu(self.hidden_layer(x))
return self.output_layer(x)
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: float = 1e-5,
inplace: bool = False
):
super().__init__()
self.weight = nn.Parameter(init_values * torch.ones(dim))
self.inplace = inplace
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.weight) if self.inplace else x * self.weight
class VisionEncoderBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int
):
super().__init__()
self.norm1 = nn.LayerNorm(dim, eps=1e-6)
self.attn = Attention(dim, num_heads)
self.ls1 = LayerScale(dim)
self.norm2 = nn.LayerNorm(dim, eps=1e-6)
self.mlp = MLP(dim)
self.ls2 = LayerScale(dim)
def forward(self, x: Tensor) -> Tensor:
x = x + self.ls1(self.attn(self.norm1(x)))
x = x + self.ls2(self.mlp(self.norm2(x)))
return x
class VisionEncoder(nn.Module):
def __init__(
self,
dim: int,
patch_size: int,
num_layers: int,
num_heads: int,
):
super().__init__()
self.n_patch = 224 // patch_size
self.seq_len = self.n_patch ** 2
self.patch_size = patch_size
self.patch_embed = nn.Conv2d(3, dim, patch_size, patch_size)
self.pos_embed = nn.Parameter(torch.randn(1, self.seq_len, dim) * 0.02)
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
self.interpolate_offset = 0.1
self.interpolate_antialias = False
self.blocks = nn.Sequential(
*[
VisionEncoderBlock(dim, num_heads)
for _ in range(num_layers)
]
)
self.norm = nn.LayerNorm(dim, eps=1e-6)
def interpolate_pos_encoding(self, x, h, w):
previous_dtype = x.dtype
if x.shape[1] == self.seq_len and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
sx, sy = float(w0) / self.n_patch, float(h0) / self.n_patch
pos_embed = nn.functional.interpolate(
pos_embed.reshape(1, self.n_patch, self.n_patch, dim).permute(0, 3, 1, 2),
scale_factor=(sy, sx),
mode="bicubic",
antialias=self.interpolate_antialias,
)
return pos_embed.to(previous_dtype).flatten(start_dim=2).transpose(2, 1)
def forward(self, x: Tensor) -> Tensor:
h, w = x.shape[2:]
x = self.patch_embed(x).flatten(start_dim=2).transpose(2, 1)
x = x + self.interpolate_pos_encoding(x, h, w)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.blocks(x)
return self.norm(x)