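# NOTE(review): the next three lines are page-extraction residue, not part of
# the module; kept as comments so the file parses.
# Spaces:
# Runtime error
# Runtime error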
# Copyright (c) OpenMMLab. All rights reserved.
import torch
class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    """A general BatchNorm layer without input dimension check.

    Reproduced from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
    is `_check_input_dim` that is designed for tensor sanity checks.
    The check has been bypassed in this class for the convenience of
    converting SyncBatchNorm.
    """

    def _check_input_dim(self, input):
        """Accept an input of any dimensionality.

        The parameter name ``input`` is kept (even though it shadows the
        builtin) so the override matches the hook it replaces.
        """
        # Intentionally a no-op: a layer converted from SyncBatchNorm does not
        # know whether it will see 1d, 2d or 3d data, so nothing is validated.
        return
def revert_sync_batchnorm(module: torch.nn.Module) -> torch.nn.Module:
    """Helper function to convert all `SyncBatchNorm` layers in the model to
    `BatchNormXd` layers.

    Adapted from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    The module tree is walked recursively. Every `SyncBatchNorm` found is
    replaced by a `_BatchNormXd` built with the same configuration. The
    replacement shares the original layer's ``weight``/``bias`` and running
    statistics (the tensors are assigned, not copied). Any other module is
    kept as the same object and only has its children re-registered.

    Args:
        module (nn.Module): The module containing `SyncBatchNorm` layers.

    Returns:
        module_output: The converted module with `BatchNormXd` layers. When
        ``module`` itself is a `SyncBatchNorm` this is a new object, so
        callers must use the return value rather than rely on in-place
        modification.
    """
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            # Re-use the existing Parameter objects so the learned values
            # carry over to the replacement layer.
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        # Running statistics are shared with the original layer as well.
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        module_output.training = module.training
        # Carry over `qconfig` only when the source layer defines one.
        if hasattr(module, 'qconfig'):
            module_output.qconfig = module.qconfig
    # Recurse into the children; `add_module` overwrites the entry registered
    # under the same name with the (possibly converted) child.
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    return module_output