# Spaces: Running on L40S
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class ChannelLastConv1d(nn.Conv1d):
    """1-D convolution that takes and returns tensors with channels in the last axis.

    ``nn.Conv1d`` convolves over the last axis and treats axis 1 as channels.
    This wrapper swaps axes 1 and 2 on the way in and again on the way out, so
    a 3-D input laid out as ``(batch, length, channels)`` can be used directly.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to a channel-last 3-D tensor.

        Args:
            x: Tensor whose axes 1 and 2 are (length, in_channels).

        Returns:
            The convolved tensor, with the channel axis moved back to the end.
        """
        # Swap to the channel-first layout for the parent conv, then swap back.
        channel_first = super().forward(x.permute(0, 2, 1))
        return channel_first.permute(0, 2, 1)
# Reference: https://github.com/Stability-AI/sd3-ref
class MLP(nn.Module):
    """Gated feed-forward block computing ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        """
        Build the three bias-free linear projections of the gated MLP.

        The requested ``hidden_dim`` is first scaled by 2/3 (truncated to an
        int) and then rounded up to the next multiple of ``multiple_of``; the
        result is the actual width of the hidden layer.

        Args:
            dim (int): Size of the input and output feature axis.
            hidden_dim (int): Nominal hidden size before the 2/3 scaling and rounding.
            multiple_of (int): The hidden width is rounded up to a multiple of this value.

        Attributes:
            w1 (nn.Linear): ``dim -> hidden`` projection fed through SiLU (the gate).
            w2 (nn.Linear): ``hidden -> dim`` output projection.
            w3 (nn.Linear): ``dim -> hidden`` projection multiplied with the gate.
        """
        super().__init__()

        scaled = int(2 * hidden_dim / 3)
        # Ceil-divide to count how many `multiple_of` chunks are needed.
        num_chunks = (scaled + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_chunks

        # Creation order (w1, w2, w3) is kept so parameter registration and
        # random initialisation order stay the same.
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        """Return ``w2(silu(w1(x)) * w3(x))`` for input ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class ConvMLP(nn.Module):
    """Gated feed-forward block built from channel-last 1-D convolutions.

    Computes ``w2(silu(w1(x)) * w3(x))`` where each projection is a
    ``ChannelLastConv1d``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        """
        Build the three bias-free convolutional projections of the gated MLP.

        The requested ``hidden_dim`` is first scaled by 2/3 (truncated to an
        int) and then rounded up to the next multiple of ``multiple_of``; the
        result is the actual number of hidden channels.

        Args:
            dim (int): Number of input and output channels.
            hidden_dim (int): Nominal hidden size before the 2/3 scaling and rounding.
            multiple_of (int): The hidden width is rounded up to a multiple of this value.
            kernel_size (int): Kernel size shared by all three convolutions.
            padding (int): Padding shared by all three convolutions.

        Attributes:
            w1 (ChannelLastConv1d): ``dim -> hidden`` conv fed through SiLU (the gate).
            w2 (ChannelLastConv1d): ``hidden -> dim`` output conv.
            w3 (ChannelLastConv1d): ``dim -> hidden`` conv multiplied with the gate.
        """
        super().__init__()

        scaled = int(2 * hidden_dim / 3)
        # Ceil-divide to count how many `multiple_of` chunks are needed.
        num_chunks = (scaled + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_chunks

        # All three convolutions share the same hyper-parameters.
        conv_kwargs = dict(bias=False, kernel_size=kernel_size, padding=padding)

        # Creation order (w1, w2, w3) is kept so parameter registration and
        # random initialisation order stay the same.
        self.w1 = ChannelLastConv1d(dim, inner_dim, **conv_kwargs)
        self.w2 = ChannelLastConv1d(inner_dim, dim, **conv_kwargs)
        self.w3 = ChannelLastConv1d(dim, inner_dim, **conv_kwargs)

    def forward(self, x):
        """Return ``w2(silu(w1(x)) * w3(x))`` for input ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)