|
from typing import Optional, Type

import torch.nn as nn


def build_normalization(norm_type: str, dim: Optional[int] = None) -> Type[nn.Module]:
""" |
|
Overview: |
|
Construct the corresponding normalization module. For beginners, |
|
refer to [this article](https://zhuanlan.zhihu.com/p/34879333) to learn more about batch normalization. |
|
Arguments: |
|
- norm_type (:obj:`str`): Type of the normalization. Currently supports ['BN', 'LN', 'IN', 'SyncBN']. |
|
- dim (:obj:`Optional[int]`): Dimension of the normalization, applicable when norm_type is in ['BN', 'IN']. |
|
Returns: |
|
- norm_func (:obj:`nn.Module`): The corresponding batch normalization function. |
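    Examples:
        >>> norm_cls = build_normalization('BN', dim=2)
        >>> norm_cls is nn.BatchNorm2d
        True
        >>> layer = norm_cls(64)  # instantiate with a channel count; 64 is an arbitrary illustration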
    """
    if dim is None:
        key = norm_type
    else:
        # 'BN' and 'IN' come in dimension-specific variants (e.g. 'BN2' -> BatchNorm2d),
        # while 'LN' and 'SyncBN' are dimension-agnostic.
        if norm_type in ['BN', 'IN']:
            key = norm_type + str(dim)
        elif norm_type in ['LN', 'SyncBN']:
            key = norm_type
        else:
            raise NotImplementedError("dim argument is not supported for norm type: {}".format(norm_type))
    # Map the lookup key to the corresponding torch.nn normalization class.
    norm_func = {
        'BN1': nn.BatchNorm1d,
        'BN2': nn.BatchNorm2d,
        'LN': nn.LayerNorm,
        'IN1': nn.InstanceNorm1d,
        'IN2': nn.InstanceNorm2d,
        'SyncBN': nn.SyncBatchNorm,
    }
    if key in norm_func:
        return norm_func[key]
    else:
        raise KeyError("invalid norm type: {}".format(key))
|
|
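# A minimal usage sketch (illustrative, not part of the library API): build the
# 2D batch-norm class and apply it to a conv feature map. The channel count (64)
# and input shape below are arbitrary assumptions chosen for demonstration.
if __name__ == '__main__':
    import torch

    norm_cls = build_normalization('BN', dim=2)  # resolves to nn.BatchNorm2d
    norm = norm_cls(64)  # normalize over 64 channels
    x = torch.randn(8, 64, 32, 32)  # (batch, channels, height, width)
    y = norm(x)
    assert y.shape == x.shape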