|
import torch |
|
from torch import Tensor, nn |
|
from transformers import PreTrainedModel |
|
from .config import AdapterConfig |
|
|
|
|
|
class Model(nn.Module):
    """Strided 1-D convolutional encoder paired with a transposed-conv decoder.

    ``conv`` maps ``num_channels`` input channels to ``num_filters`` feature
    channels with a ``window_length``-wide kernel advanced by ``stride``
    samples, reflect-padding the input by
    ``window_length // 2 - stride // 2`` on each side. ``decode`` is an
    ``nn.ConvTranspose1d`` with mirrored channel counts and the same kernel
    size, stride and padding. Neither layer has a bias term.
    """

    def __init__(
        self,
        num_channels: int,
        num_filters: int,
        window_length: int,
        stride: int,
    ):
        super().__init__()
        self.stride = stride

        # Geometry shared by the forward and the transposed convolution.
        half_window = window_length // 2
        half_stride = stride // 2
        geometry = dict(
            kernel_size=window_length,
            stride=stride,
            padding=half_window - half_stride,
            bias=False,
        )

        # Build ``conv`` before ``decode``: construction order fixes the order
        # in which the layers' default initialisers draw from the RNG.
        self.conv = nn.Conv1d(
            num_channels,
            num_filters,
            padding_mode="reflect",
            **geometry,
        )
        # ConvTranspose1d only supports zero padding, so no padding_mode here.
        self.decode = nn.ConvTranspose1d(num_filters, num_channels, **geometry)

    def encode(self, x: Tensor) -> Tensor:
        """Run ``conv`` over ``x`` and squash the result with ``tanh``.

        Output values are therefore bounded to [-1, 1].
        """
        activations = self.conv(x)
        return torch.tanh(activations)
|
|
|
|
|
class Adapter(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a fixed-size :class:`Model`.

    The wrapped model's hyperparameters are hard-coded below (2 channels,
    128 filters, window length 128, stride 64). ``config`` is passed on to
    ``PreTrainedModel.__init__`` but is not consulted for these values.
    """

    config_class = AdapterConfig

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        # NOTE(review): these could be read from ``config``; left fixed so
        # existing checkpoints keep matching parameter shapes.
        hyperparameters = {
            "num_channels": 2,
            "num_filters": 128,
            "window_length": 128,
            "stride": 64,
        }
        self.model = Model(**hyperparameters)

    def encode(self, x: Tensor) -> Tensor:
        """Delegate to the wrapped model's ``encode``."""
        return self.model.encode(x)

    def decode(self, x: Tensor) -> Tensor:
        """Delegate to the wrapped model's ``decode`` layer."""
        return self.model.decode(x)
|
|
|
|
|
|