File size: 376 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmocr.models.textrecog.fusers import ABIFuser


def test_base_alignment():
    """Check that ABIFuser yields logits of shape (batch, seq_len, num_chars).

    Two random feature tensors of identical shape (batch, seq_len, d_model),
    as constructed below, are passed to the fuser; only the shape of the
    ``'logits'`` entry of the returned mapping is asserted, not its values.
    """
    batch, seq_len, d_model, num_chars = 1, 40, 512, 90
    fuser = ABIFuser(d_model=d_model, num_chars=num_chars, max_seq_len=seq_len)

    # The language-branch features are drawn first and the vision-branch
    # features second, matching the argument order of the call below.
    lang_feat = torch.randn(batch, seq_len, d_model)
    vis_feat = torch.randn(batch, seq_len, d_model)

    out = fuser(lang_feat, vis_feat)
    expected = torch.Size([batch, seq_len, num_chars])
    assert out['logits'].shape == expected