Spaces:
Runtime error
Runtime error
File size: 740 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.cnn.bricks import ConvModule
from mmocr.utils import revert_sync_batchnorm
def test_revert_sync_batchnorm():
    """Check that ``revert_sync_batchnorm`` makes a SyncBN model CPU-runnable.

    A ``ConvModule`` built with ``norm_cfg=dict(type='SyncBN')`` is passed
    through ``revert_sync_batchnorm``. The result must contain no
    ``torch.nn.SyncBatchNorm`` submodules, must run a forward pass on a CPU
    tensor, and must keep the train/eval mode of the module it came from.
    """
    conv_syncbn = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu')
    conv_syncbn.train()
    x = torch.randn(1, 3, 10, 10)

    # Older PyTorch releases reject CPU input to SyncBatchNorm with a
    # ValueError; newer releases fall back to plain batch norm when
    # torch.distributed is not initialized. Accept either behaviour so this
    # test checks revert_sync_batchnorm rather than PyTorch internals.
    try:
        conv_syncbn(x)
    except ValueError:
        pass

    conv_bn = revert_sync_batchnorm(conv_syncbn)
    # The conversion must have replaced every SyncBatchNorm layer.
    assert not any(
        isinstance(m, torch.nn.SyncBatchNorm) for m in conv_bn.modules())
    y = conv_bn(x)
    # (1, 8, 9, 9): 8 output channels; a 2x2 kernel over a 10x10 input
    # (stride 1, no padding) yields a 9x9 map.
    assert y.shape == (1, 8, 9, 9)
    assert conv_bn.training == conv_syncbn.training

    # The train/eval flag must also be carried over in eval mode.
    conv_syncbn.eval()
    conv_bn = revert_sync_batchnorm(conv_syncbn)
    assert conv_bn.training == conv_syncbn.training
|