|
import monai |
|
import torch |
|
import itk |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from typing import Optional, Sequence, Tuple, Union |
|
from monai.networks.layers.factories import Act, Norm |
|
|
|
|
|
def segNet(
    spatial_dim: int,
    in_channel: int,
    out_channel: int,
    channel: Sequence[int],
    stride: Sequence[int],
    num_res_unit: int = 0,
    acts: Union[Tuple, str] = Act.PRELU,
    norms: Union[Tuple, str] = Norm.INSTANCE,
    dropouts: float = 0.0,
):
    """Create a ``monai.networks.nets.UNet`` from the given settings.

    This is a thin factory: every argument is forwarded unchanged to the
    ``UNet`` constructor under the keyword listed below. No validation is
    done here; anything ``UNet`` rejects will raise from its constructor.
    NOTE(review): the name suggests the network is meant for segmentation,
    but nothing in this function is segmentation-specific.

    Args:
        spatial_dim: Forwarded as ``spatial_dims``.
        in_channel: Forwarded as ``in_channels``.
        out_channel: Forwarded as ``out_channels``.
        channel: Forwarded as ``channels``.
        stride: Forwarded as ``strides``.
        num_res_unit: Forwarded as ``num_res_units``. Defaults to 0.
        acts: Forwarded as ``act``. Defaults to ``Act.PRELU``.
        norms: Forwarded as ``norm``. Defaults to ``Norm.INSTANCE``.
        dropouts: Forwarded as ``dropout``. Defaults to 0.0.

    Returns:
        The newly constructed ``UNet`` instance.
    """
    # Map this function's singular argument names onto UNet's keyword names.
    unet_kwargs = {
        "spatial_dims": spatial_dim,
        "in_channels": in_channel,
        "out_channels": out_channel,
        "channels": channel,
        "strides": stride,
        "num_res_units": num_res_unit,
        "act": acts,
        "norm": norms,
        "dropout": dropouts,
    }
    return monai.networks.nets.UNet(**unet_kwargs)
|
|
|
def regNet(
    spatial_dim: int,
    in_channel: int,
    out_channel: int,
    channel: Sequence[int],
    stride: Sequence[int],
    num_res_unit: int = 0,
    acts: Union[Tuple, str] = Act.PRELU,
    norms: Union[Tuple, str] = Norm.INSTANCE,
    dropouts: float = 0.0,
):
    """Build and return a ``monai.networks.nets.UNet`` instance.

    All arguments are handed straight to the ``UNet`` constructor without
    modification or checking; constructor errors propagate to the caller.
    NOTE(review): the name suggests a registration network, but the
    construction performed here is a plain ``UNet`` with the supplied
    settings -- confirm intended usage with the caller.

    Args:
        spatial_dim: Value for UNet's ``spatial_dims``.
        in_channel: Value for UNet's ``in_channels``.
        out_channel: Value for UNet's ``out_channels``.
        channel: Value for UNet's ``channels``.
        stride: Value for UNet's ``strides``.
        num_res_unit: Value for UNet's ``num_res_units`` (default 0).
        acts: Value for UNet's ``act`` (default ``Act.PRELU``).
        norms: Value for UNet's ``norm`` (default ``Norm.INSTANCE``).
        dropouts: Value for UNet's ``dropout`` (default 0.0).

    Returns:
        The constructed ``UNet`` object.
    """
    unet_cls = monai.networks.nets.UNet
    return unet_cls(
        spatial_dims=spatial_dim,
        in_channels=in_channel,
        out_channels=out_channel,
        channels=channel,
        strides=stride,
        num_res_units=num_res_unit,
        act=acts,
        norm=norms,
        dropout=dropouts,
    )
|
|