File size: 444 Bytes
186701e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch
import torch.nn as nn
from torch import Tensor
class DeployC2f(nn.Module):
    """Forward-only module that concatenates a stem output with chained block outputs.

    ``__init__`` ignores every argument and registers no submodules. ``forward``
    nevertheless reads the following attributes, so they must be supplied on the
    instance from outside this class before it is called:

    * ``main_conv``    -- callable applied to the input tensor.
    * ``mid_channels`` -- int start index of the slice taken along dimension 1.
    * ``blocks``       -- iterable of callables applied one after another.
    * ``final_conv``   -- callable applied to the concatenated tensor.

    NOTE(review): presumably this class is swapped onto an already-built module
    (e.g. via ``module.__class__ = DeployC2f``) so that those attributes already
    exist -- confirm against the call site.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted for signature compatibility and discarded.
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Run the stem, chain the blocks on its tail slice, and fuse the results.

        Args:
            x: Input tensor handed to ``self.main_conv``.

        Returns:
            ``self.final_conv`` applied to the concatenation, along dimension 1,
            of the full ``main_conv`` output followed by each block's output.
        """
        stem = self.main_conv(x)
        # The full stem output is kept as-is; only its tail along dim 1 seeds
        # the block chain and is not itself concatenated a second time.
        outputs = [stem]
        feat = stem[:, self.mid_channels:, ...]
        # Each block consumes the previous block's output (the first consumes
        # the tail slice); every intermediate result is collected for the cat.
        for block in self.blocks:
            feat = block(feat)
            outputs.append(feat)
        return self.final_conv(torch.cat(outputs, 1))
|