Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
import torch | |
__all__ = ['SharedMLP'] | |
class Swish(nn.Module):
    """Swish activation module: multiplies the input by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)``, computed element-wise.

        Args:
            x: Input tensor.

        Returns:
            The element-wise product of ``x`` and ``torch.sigmoid(x)``.
        """
        return x * torch.sigmoid(x)
class SharedMLP(nn.Module):
    """Chain of kernel-size-1 convolution blocks.

    Each block is ``conv -> GroupNorm -> Swish``. The blocks are stacked in an
    ``nn.Sequential`` so that the output width of one block is the input width
    of the next.

    Args:
        in_channels: Number of input channels for the first convolution.
        out_channels: Output width of each block. A single value builds one
            block; a list or tuple builds one block per entry, in order.
        dim: ``1`` selects ``nn.Conv1d`` (default), ``2`` selects
            ``nn.Conv2d``.
        num_groups: Number of groups passed to every ``nn.GroupNorm``.
            Defaults to 8, the value that used to be hard-coded.

    Raises:
        ValueError: If ``dim`` is neither 1 nor 2.
    """

    def __init__(self, in_channels, out_channels, dim=1, num_groups=8):
        super().__init__()
        if dim == 1:  # default value
            conv = nn.Conv1d
        elif dim == 2:
            conv = nn.Conv2d
        else:
            raise ValueError(f"dim must be 1 or 2, got {dim!r}")

        # Accept a bare width as shorthand for a single-block MLP.
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [out_channels]

        layers = []
        for oc in out_channels:
            layers.extend([
                conv(in_channels, oc, 1),
                nn.GroupNorm(num_groups, oc),
                Swish(),
            ])
            # The next block consumes what this one produced.
            in_channels = oc
        self.layers = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the stacked blocks to a tensor or to the head of a sequence.

        Args:
            inputs: Either a single input for ``self.layers``, or a list/tuple
                whose first element is that input.

        Returns:
            For a list/tuple: a tuple whose first element is the transformed
            ``inputs[0]`` and whose remaining elements are ``inputs[1:]``
            passed through unchanged. Otherwise: the transformed input.
        """
        if isinstance(inputs, (list, tuple)):
            return (self.layers(inputs[0]), *inputs[1:])
        return self.layers(inputs)