Spaces:
Paused
Paused
from typing import Tuple, Union | |
import torch | |
import torch.nn as nn | |
class CausalConv3d(nn.Module):
    """3-D convolution that can be causal along the time axis (dim 2).

    Input is a 5-D tensor laid out as ``(N, C, T, H, W)``, as expected by the
    wrapped ``nn.Conv3d``. Height and width are zero-padded by
    ``kernel_size // 2`` inside the convolution. The time axis is not padded
    by the convolution; ``forward`` instead replicates boundary frames:

    * causal: the first frame is repeated in front, so an output frame only
      depends on the current and earlier input frames;
    * non-causal: the first and last frames are repeated on both sides.

    With stride 1 and an odd temporal kernel, the number of output frames
    equals the number of input frames in both modes.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size. A scalar is used for all three dimensions;
            a 3-sequence is interpreted as ``(time, height, width)``.
        stride: Stride forwarded to ``nn.Conv3d``.
        dilation: Dilation along the time axis only (spatial dilation is 1).
        groups: Number of blocked connections, forwarded to ``nn.Conv3d``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if isinstance(kernel_size, (tuple, list)):
            kernel_size = tuple(kernel_size)
            if len(kernel_size) != 3:
                raise ValueError(
                    f"kernel_size must be a scalar or a 3-sequence, got {kernel_size!r}"
                )
        else:
            kernel_size = (kernel_size, kernel_size, kernel_size)

        self.time_kernel_size = kernel_size[0]
        self.time_dilation = dilation
        # A temporal kernel of size k with dilation d spans d * (k - 1) + 1
        # frames, so d * (k - 1) extra frames are needed to keep the output
        # length and (in causal mode) to avoid reading future frames.
        self.time_pad = dilation * (self.time_kernel_size - 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=(dilation, 1, 1),
            # Time is padded manually in forward(); only H and W are zero-padded here.
            padding=(0, height_pad, width_pad),
            padding_mode="zeros",
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        """Pad the time axis by frame replication, then convolve.

        Args:
            x: Tensor of shape ``(N, C, T, H, W)``.
            causal: If True, pad only in front of the first frame; otherwise
                pad symmetrically with the first and last frames.

        Returns:
            The output of the wrapped ``nn.Conv3d``.
        """
        if causal:
            front, back = self.time_pad, 0
        else:
            front = back = self.time_pad // 2
        x = self._replicate_time_pad(x, front, back)
        return self.conv(x)

    @staticmethod
    def _replicate_time_pad(x, front, back):
        """Prepend ``front`` copies of the first frame and append ``back`` copies of the last."""
        parts = []
        if front > 0:
            parts.append(x[:, :, :1, :, :].repeat(1, 1, front, 1, 1))
        parts.append(x)
        if back > 0:
            parts.append(x[:, :, -1:, :, :].repeat(1, 1, back, 1, 1))
        # Skip the copy entirely when no padding is required.
        return torch.cat(parts, dim=2) if len(parts) > 1 else x

    def weight(self):
        """Return the weight parameter of the wrapped ``nn.Conv3d``."""
        return self.conv.weight