# Copyright (c) OpenMMLab. All rights reserved.
from queue import PriorityQueue

import torch
import torch.nn.functional as F

import mmocr.utils as utils
from mmocr.models.builder import DECODERS
from . import ParallelSARDecoder


class DecodeNode:
    """Node class to save decoded char indices and scores.

    Args:
        indexes (list[int], optional): Char indices decoded so far.
            Defaults to ``[1]`` when omitted.
        scores (list[float], optional): Char scores decoded so far, one per
            entry of ``indexes``. Defaults to ``[0.9]`` when omitted.
    """

    def __init__(self, indexes=None, scores=None):
        # Build the default lists per call: list literals in the signature
        # would be shared by every node constructed without arguments.
        if indexes is None:
            indexes = [1]
        if scores is None:
            scores = [0.9]
        assert utils.is_type_list(indexes, int)
        assert utils.is_type_list(scores, float)
        assert utils.equal_len(indexes, scores)

        self.indexes = indexes
        self.scores = scores

    def eval(self):
        """Calculate accumulated score (the plain sum of ``scores``)."""
        accu_score = sum(self.scores)
        return accu_score


@DECODERS.register_module()
class ParallelSARDecoderWithBS(ParallelSARDecoder):
    """Parallel Decoder module with beam-search in SAR.

    Args:
        beam_width (int): Width for beam search. Must be a positive int.

    All remaining arguments are forwarded unchanged to
    ``ParallelSARDecoder.__init__``; ``**kwargs`` is accepted but not used.
    """

    def __init__(self,
                 beam_width=5,
                 num_classes=37,
                 enc_bi_rnn=False,
                 dec_bi_rnn=False,
                 dec_do_rnn=0,
                 dec_gru=False,
                 d_model=512,
                 d_enc=512,
                 d_k=64,
                 pred_dropout=0.0,
                 max_seq_len=40,
                 mask=True,
                 start_idx=0,
                 padding_idx=0,
                 pred_concat=False,
                 init_cfg=None,
                 **kwargs):
        super().__init__(
            num_classes,
            enc_bi_rnn,
            dec_bi_rnn,
            dec_do_rnn,
            dec_gru,
            d_model,
            d_enc,
            d_k,
            pred_dropout,
            max_seq_len,
            mask,
            start_idx,
            padding_idx,
            pred_concat,
            init_cfg=init_cfg)
        assert isinstance(beam_width, int)
        assert beam_width > 0

        self.beam_width = beam_width

    def forward_test(self, feat, out_enc, img_metas):
        """Decode a single sample with beam search.

        Args:
            feat (Tensor): Feature tensor whose first dimension is the batch
                size; it is passed through to ``self._2d_attention``.
            out_enc (Tensor): Encoder output, used as the first element of
                the decoder input sequence.
            img_metas (list[dict]): One meta dict per sample. When
                ``self.mask`` is truthy, ``valid_ratio`` is read from each
                dict (default ``1.0``).

        Returns:
            Tensor: Shape ``(1, max_seq_len, num_classes - 1)``. It is zero
            everywhere except that, at each step, the entry of the char
            chosen on the best beam holds that char's softmax score.

        Raises:
            AssertionError: If ``img_metas`` is not a list of dicts, its
                length differs from the batch size, or the batch size is
                not 1.
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == feat.size(0)

        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ] if self.mask else None

        seq_len = self.max_seq_len
        bsz = feat.size(0)
        assert bsz == 1, 'batch size must be 1 for beam search.'

        start_token = torch.full((bsz, ),
                                 self.start_idx,
                                 device=feat.device,
                                 dtype=torch.long)
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = torch.cat((out_enc, start_token), dim=1)
        # bsz * (seq_len + 1) * emb_dim

        # Initialize beam-search queue.
        # Entries are (priority, insertion_order, node). The integer in the
        # middle breaks ties between equal priorities, so the queue never
        # has to compare two DecodeNode objects (which define no ordering
        # and would raise TypeError).
        q = PriorityQueue()
        init_node = DecodeNode([self.start_idx], [0.0])
        q.put((-init_node.eval(), 0, init_node))

        for i in range(1, seq_len + 1):
            next_nodes = []
            # Only the start node exists at the first step.
            beam_width = self.beam_width if i > 1 else 1
            for _ in range(beam_width):
                _, _, node = q.get()

                input_seq = torch.clone(decoder_input)  # bsz * T * emb_dim
                # fill previous input tokens (step 1...i) in input_seq
                for t, index in enumerate(node.indexes):
                    input_token = torch.full((bsz, ),
                                             index,
                                             device=input_seq.device,
                                             dtype=torch.long)
                    input_token = self.embedding(input_token)  # bsz * emb_dim
                    input_seq[:, t + 1, :] = input_token

                output_seq = self._2d_attention(
                    input_seq, feat, out_enc, valid_ratios=valid_ratios)

                output_char = output_seq[:, i, :]  # bsz * num_classes
                output_char = F.softmax(output_char, -1)
                topk_value, topk_idx = output_char.topk(self.beam_width, dim=1)
                topk_value, topk_idx = topk_value.squeeze(0), topk_idx.squeeze(
                    0)

                # The parent's accumulated score is the same for all of its
                # children, so compute it once.
                parent_score = node.eval()
                for k in range(self.beam_width):
                    kth_score = topk_value[k].item()
                    kth_idx = topk_idx[k].item()
                    next_node = DecodeNode(node.indexes + [kth_idx],
                                           node.scores + [kth_score])
                    # Tiny rank-dependent offset keeps a parent's children
                    # in top-k order when their scores are equal.
                    delta = k * 1e-6
                    # Use minus since priority queue sort
                    # with ascending order
                    next_nodes.append((-parent_score - kth_score - delta,
                                       len(next_nodes), next_node))

            # Drop whatever is left from the previous step.
            while not q.empty():
                q.get()

            # Put all candidates to queue
            for next_node in next_nodes:
                q.put(next_node)

        _, _, best_node = q.get()
        num_classes = self.num_classes - 1  # ignore padding index
        # NOTE(review): no device is given, so this tensor lives on the
        # default device rather than on feat.device - confirm that callers
        # expect this.
        outputs = torch.zeros(bsz, seq_len, num_classes)
        for i in range(seq_len):
            # indexes[0] / scores[0] belong to the start token, hence i + 1.
            idx = best_node.indexes[i + 1]
            outputs[0, i, idx] = best_node.scores[i + 1]

        return outputs