""" Normalization + Activation Layers | |
""" | |
import torch | |
from torch import nn as nn | |
from torch.nn import functional as F | |
from .create_act import get_act_layer | |


class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super(BatchNormAct2d, self).__init__(
            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        if isinstance(act_layer, str):
            act_layer = get_act_layer(act_layer)
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()
    def _forward_jit(self, x):
        """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function
        """
        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
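        # NOTE in training mode F.batch_norm updates the running stats as
        #   running_stat = (1 - factor) * running_stat + factor * batch_stat
        # where factor is exponential_average_factor; with momentum=None the factor
        # 1 / num_batches_tracked above yields a cumulative (simple) moving average.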
        x = F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
        return x
    def _forward_python(self, x):
        return super(BatchNormAct2d, self).forward(x)

    def forward(self, x):
        # FIXME cannot call parent forward() and maintain jit.script compatibility?
        if torch.jit.is_scripting():
            x = self._forward_jit(x)
        else:
            x = self._forward_python(x)
        x = self.act(x)
        return x
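
# Minimal usage sketch (assumption, not part of the original file): because
# BatchNormAct2d subclasses nn.BatchNorm2d instead of wrapping it, a checkpoint
# saved from a plain BatchNorm2d loads directly, since the state_dict keys match:
#
#   bn = nn.BatchNorm2d(64)
#   bn_act = BatchNormAct2d(64, act_layer=nn.ReLU)
#   bn_act.load_state_dict(bn.state_dict())   # weight, bias, running_mean, running_var, ...
#   y = bn_act(torch.randn(2, 64, 8, 8))      # normalize, then activate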


class GroupNormAct(nn.GroupNorm):
    # NOTE num_channels and num_groups order flipped for easier layer swaps / binding of fixed args
    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
        if isinstance(act_layer, str):
            act_layer = get_act_layer(act_layer)
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()
    def forward(self, x):
        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = self.act(x)
        return x
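

# Minimal sketch (not part of the original module) illustrating why GroupNormAct takes
# num_channels first: the layer can be bound with a fixed group count and then used as
# a drop-in replacement wherever a norm layer is constructed as norm_layer(num_features).
# Run as a module (e.g. `python -m ...norm_act`) because of the relative import above.
if __name__ == '__main__':
    from functools import partial

    norm_layer = partial(GroupNormAct, num_groups=8, act_layer=nn.ReLU)
    gn_act = norm_layer(32)  # 32 channels split into 8 groups, ReLU applied after the norm
    out = gn_act(torch.randn(2, 32, 16, 16))
    print(out.shape)  # torch.Size([2, 32, 16, 16])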