File size: 561 Bytes
810c8ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch

from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN


def test_unetdiscriminatorsn():
    """Test arch: UNetDiscriminatorSN.

    Builds a small discriminator (3 input channels, 4 base features, skip
    connections on) and checks that a 1x3x32x32 input produces a 1x1x32x32
    output. The check runs on the CPU and, when a CUDA device is present,
    again on the GPU with the same network and input.
    """
    expected_shape = (1, 1, 32, 32)
    model = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
    sample = torch.rand((1, 3, 32, 32), dtype=torch.float32)

    # CPU forward pass: the model is created on the CPU, so no device move is needed.
    assert model(sample).shape == expected_shape

    # Skip the GPU half of the test on machines without CUDA.
    if not torch.cuda.is_available():
        return

    # GPU forward pass: move the model in place, then send the input to the device too.
    model.cuda()
    assert model(sample.cuda()).shape == expected_shape