# Source: MMOCR / mmocr/models/textrecog/decoders/sar_decoder_with_bs.py
# Uploaded by tomofi ("Add application file", commit 2366e36, 5.47 kB).
# Copyright (c) OpenMMLab. All rights reserved.
from queue import PriorityQueue
import torch
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.models.builder import DECODERS
from . import ParallelSARDecoder
class DecodeNode:
    """Node class to save decoded char indices and scores.

    Args:
        indexes (list[int]): Char indices decoded so far. Defaults to
            ``[1]``.
        scores (list[float]): Score of each decoded char, aligned
            one-to-one with ``indexes``. Defaults to ``[0.9]``.
    """

    def __init__(self, indexes=None, scores=None):
        # None sentinels instead of list literals: a list default is created
        # once and would be shared (and stored) by every default-built node.
        if indexes is None:
            indexes = [1]
        if scores is None:
            scores = [0.9]
        assert utils.is_type_list(indexes, int)
        assert utils.is_type_list(scores, float)
        assert utils.equal_len(indexes, scores)
        self.indexes = indexes
        self.scores = scores

    def eval(self):
        """Calculate accumulated score, i.e. the sum of ``self.scores``."""
        accu_score = sum(self.scores)
        return accu_score

    def __lt__(self, other):
        """Order nodes by accumulated score (lower score sorts first).

        This keeps ``(priority, node)`` tuples comparable when their
        priorities are equal, e.g. inside ``heapq`` or
        ``queue.PriorityQueue``, instead of raising ``TypeError``.
        """
        if not isinstance(other, DecodeNode):
            return NotImplemented
        return self.eval() < other.eval()
@DECODERS.register_module()
class ParallelSARDecoderWithBS(ParallelSARDecoder):
    """Parallel Decoder module with beam-search in SAR.

    Only ``__init__`` and :meth:`forward_test` are defined here; everything
    else is inherited from :class:`ParallelSARDecoder`. At test time the
    decoder keeps the ``beam_width`` best partial sequences and extends each
    one with its ``beam_width`` most probable next characters. It then
    returns the best full-length sequence. A sequence is ranked by the sum
    of the softmax probabilities of its characters (``DecodeNode.eval``).

    NOTE(review): textbook beam search sums log-probabilities; this ranks by
    summed raw probabilities. Changing it would change results, so it is
    left as is.

    Args:
        beam_width (int): Width for beam search; must be a positive int.
        num_classes, enc_bi_rnn, dec_bi_rnn, dec_do_rnn, dec_gru, d_model,
        d_enc, d_k, pred_dropout, max_seq_len, mask, start_idx, padding_idx,
        pred_concat: Forwarded positionally to :class:`ParallelSARDecoder`.
        init_cfg: Forwarded to :class:`ParallelSARDecoder` by keyword.
        **kwargs: Accepted but ignored (not forwarded to the parent).
    """

    def __init__(self,
                 beam_width=5,
                 num_classes=37,
                 enc_bi_rnn=False,
                 dec_bi_rnn=False,
                 dec_do_rnn=0,
                 dec_gru=False,
                 d_model=512,
                 d_enc=512,
                 d_k=64,
                 pred_dropout=0.0,
                 max_seq_len=40,
                 mask=True,
                 start_idx=0,
                 padding_idx=0,
                 pred_concat=False,
                 init_cfg=None,
                 **kwargs):
        super().__init__(
            num_classes,
            enc_bi_rnn,
            dec_bi_rnn,
            dec_do_rnn,
            dec_gru,
            d_model,
            d_enc,
            d_k,
            pred_dropout,
            max_seq_len,
            mask,
            start_idx,
            padding_idx,
            pred_concat,
            init_cfg=init_cfg)
        assert isinstance(beam_width, int)
        assert beam_width > 0
        self.beam_width = beam_width

    def _embed_prefix(self, decoder_input, indexes):
        """Return a copy of ``decoder_input`` with a token prefix embedded.

        Args:
            decoder_input (Tensor): Shape (bsz, seq_len + 1, emb_dim).
            indexes (list[int]): Token indices; ``indexes[t]`` is embedded
                at position ``t + 1``. Position 0 and the positions after
                the prefix are left untouched.

        Returns:
            Tensor: A new tensor; ``decoder_input`` itself is not modified.
        """
        input_seq = torch.clone(decoder_input)
        prefix = torch.tensor(
            indexes, device=input_seq.device, dtype=torch.long)
        # (len(indexes), emb_dim), broadcast over the batch dimension.
        input_seq[:, 1:len(indexes) + 1, :] = self.embedding(prefix)
        return input_seq

    def forward_test(self, feat, out_enc, img_metas):
        """Decode a single image with beam search.

        Args:
            feat (Tensor): Feature map passed to ``_2d_attention``. Its first
                dimension is the batch size, which must be 1.
            out_enc (Tensor): Encoder output of shape (bsz, emb_dim). It is
                prepended to the decoder input and also passed to
                ``_2d_attention``.
            img_metas (list[dict]): One meta dict per image. When
                ``self.mask`` is set, ``valid_ratio`` (default 1.0) is read.

        Returns:
            Tensor: Shape (1, max_seq_len, num_classes - 1). It is created
            with torch's default device and dtype, not moved to
            ``feat.device``. At each step only the entry of the character
            chosen by the best beam is non-zero; it holds that character's
            softmax probability.
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == feat.size(0)

        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ] if self.mask else None

        seq_len = self.max_seq_len
        bsz = feat.size(0)
        assert bsz == 1, 'batch size must be 1 for beam search.'

        start_token = torch.full((bsz, ),
                                 self.start_idx,
                                 device=feat.device,
                                 dtype=torch.long)
        start_token = self.embedding(start_token)  # bsz * emb_dim
        # bsz * seq_len * emb_dim
        start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
        out_enc = out_enc.unsqueeze(1)  # bsz * 1 * emb_dim
        # bsz * (seq_len + 1) * emb_dim
        decoder_input = torch.cat((out_enc, start_token), dim=1)

        # Candidates are (priority, node) pairs. Priority is the negated
        # accumulated score, so smaller is better. They are always ordered
        # with a key on the priority alone, so equal priorities never fall
        # back to comparing the DecodeNode objects themselves. The stable
        # sort keeps ties in insertion order.
        init_node = DecodeNode([self.start_idx], [0.0])
        candidates = [(-init_node.eval(), init_node)]

        for i in range(1, seq_len + 1):
            # At step 1 this is just the root node. Slicing never blocks or
            # fails when fewer than beam_width candidates exist.
            beams = sorted(candidates, key=lambda cand: cand[0])
            beams = beams[:self.beam_width]
            candidates = []
            for _, node in beams:
                input_seq = self._embed_prefix(decoder_input, node.indexes)
                output_seq = self._2d_attention(
                    input_seq, feat, out_enc, valid_ratios=valid_ratios)
                # Class probabilities for step i, one row per batch item.
                output_char = F.softmax(output_seq[:, i, :], -1)
                # topk() raises if k exceeds the number of classes.
                k = min(self.beam_width, output_char.size(-1))
                topk_value, topk_idx = output_char.topk(k, dim=1)
                # bsz == 1: one host transfer per beam, not per candidate.
                topk_scores = topk_value[0].tolist()
                topk_indexes = topk_idx[0].tolist()
                for rank, (score, index) in enumerate(
                        zip(topk_scores, topk_indexes)):
                    next_node = DecodeNode(node.indexes + [index],
                                           node.scores + [score])
                    # Rank-dependent offset so siblings with equal
                    # probabilities get distinct priorities.
                    # NOTE(review): it makes lower-ranked siblings sort
                    # slightly *earlier*; harmless as a tie-breaker, but
                    # confirm the intended direction.
                    delta = rank * 1e-6
                    candidates.append(
                        (-node.eval() - score - delta, next_node))

        # Best full-length sequence; min() returns the first of equal keys.
        best_node = min(candidates, key=lambda cand: cand[0])[1]

        num_classes = self.num_classes - 1  # ignore padding index
        outputs = torch.zeros(bsz, seq_len, num_classes)
        # Entry 0 of indexes/scores is the start token; skip it.
        decoded = zip(best_node.indexes[1:], best_node.scores[1:])
        for step, (index, score) in enumerate(decoded):
            outputs[0, step, index] = score
        return outputs