File size: 1,893 Bytes
6ff2047 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import torch
from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace
def test_resnetarcface():
    """Check the output shape of ResNetArcFace with and without SE blocks.

    The checks only run when CUDA is available; on a host without CUDA the
    function returns immediately without asserting anything.
    """
    if not torch.cuda.is_available():
        return

    stage_layers = (2, 2, 2, 2)
    expected_shape = (1, 512)

    # -------------------- with SE block (gpu) ----------------------- #
    model = ResNetArcFace(block='IRBlock', layers=stage_layers, use_se=True)
    model = model.cuda().eval()
    sample = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
    feature = model(sample)
    assert feature.shape == expected_shape

    # -------------------- without SE block ----------------------- #
    # The same input tensor is reused for the second model.
    model = ResNetArcFace(block='IRBlock', layers=stage_layers, use_se=False)
    model = model.cuda().eval()
    feature = model(sample)
    assert feature.shape == expected_shape
def test_basicblock():
    """Test the BasicBlock in arcface_arch.

    Runs on CUDA when it is available and falls back to CPU otherwise, so
    the test no longer errors out on hosts without a GPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------- without a downsample module --------------- #
    # NOTE(review): the positional args are presumably (inplanes, planes);
    # confirm against BasicBlock's signature.
    block = BasicBlock(1, 3, stride=1, downsample=None).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).to(device)
    output = block(img)
    assert output.shape == (1, 3, 12, 12)

    # ----------------- use the downsample module --------------- #
    # scale_factor=0.5 halves the spatial size, matching the stride-2 block.
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).to(device)
    block = BasicBlock(1, 3, stride=2, downsample=downsample).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).to(device)
    output = block(img)
    assert output.shape == (1, 3, 6, 6)
def test_bottleneck():
    """Test the Bottleneck in arcface_arch.

    Runs on CUDA when it is available and falls back to CPU otherwise, so
    the test no longer errors out on hosts without a GPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------- without a downsample module --------------- #
    # NOTE(review): the 4 output channels presumably come from the block's
    # channel expansion; confirm against Bottleneck's definition.
    block = Bottleneck(1, 1, stride=1, downsample=None).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).to(device)
    output = block(img)
    assert output.shape == (1, 4, 12, 12)

    # ----------------- use the downsample module --------------- #
    # scale_factor=0.5 halves the spatial size, matching the stride-2 block.
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).to(device)
    block = Bottleneck(1, 1, stride=2, downsample=downsample).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).to(device)
    output = block(img)
    assert output.shape == (1, 4, 6, 6)
|