Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
from mmcv.runner import BaseModule | |
class RobustScannerFusionLayer(BaseModule): | |
def __init__(self, dim_model, dim=-1, init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
self.dim_model = dim_model | |
self.dim = dim | |
self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2) | |
self.glu_layer = nn.GLU(dim=dim) | |
def forward(self, x0, x1): | |
assert x0.size() == x1.size() | |
fusion_input = torch.cat([x0, x1], self.dim) | |
output = self.linear_layer(fusion_input) | |
output = self.glu_layer(output) | |
return output | |