# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList

from mmocr.models.builder import DECODERS
from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from .base_decoder import BaseDecoder


@DECODERS.register_module()
class NRTRDecoder(BaseDecoder):
    """Transformer Decoder block with self attention mechanism.

    Args:
        n_layers (int): Number of attention layers.
        d_embedding (int): Language embedding dimension.
        n_head (int): Number of parallel attention heads.
        d_k (int): Dimension of the key vector.
        d_v (int): Dimension of the value vector.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
        d_inner (int): Hidden dimension of feedforward layers.
        n_position (int): Length of the positional encoding vector. Must be
            greater than ``max_seq_len``.
        dropout (float): Dropout rate.
        num_classes (int): Number of output classes :math:`C`.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        padding_idx (int): The index of `<PAD>`.
        init_cfg (dict or list[dict], optional): Initialization configs.
        **kwargs: Extra keyword arguments forwarded unchanged to every
            ``TFDecoderLayer``.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        # Token embedding; ``padding_idx`` is handed to ``nn.Embedding`` so
        # the padding row is treated specially by the embedding layer.
        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = ModuleList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        # The classifier never predicts the padding class, hence ``C - 1``
        # outputs (see the Warning section of the class docstring).
        pred_num_class = num_classes - 1  # ignore padding_idx
        self.classifier = nn.Linear(d_model, pred_num_class)

    @staticmethod
    def get_pad_mask(seq, pad_idx):
        """Return a boolean mask that is ``True`` at non-padding positions.

        Args:
            seq (Tensor): Index tensor to inspect.
            pad_idx (int): Index value that marks padding.

        Returns:
            Tensor: ``seq != pad_idx`` with an extra singleton dimension
            inserted at position ``-2`` so that it broadcasts against a
            square attention mask.
        """
        return (seq != pad_idx).unsqueeze(-2)

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info.

        Args:
            seq (Tensor): Tensor whose second dimension is the sequence
                length ``len_s``.

        Returns:
            Tensor: Boolean lower-triangular mask (diagonal included) of
            shape ``(1, len_s, len_s)`` on the same device as ``seq``;
            entry ``[0, i, j]`` is ``True`` iff ``j <= i``.
        """
        len_s = seq.size(1)
        # tril(ones) is exactly ``1 - triu(ones, diagonal=1)``.
        subsequent_mask = torch.tril(
            torch.ones((len_s, len_s), device=seq.device))
        return subsequent_mask.unsqueeze(0).bool()

    def _attention(self, trg_seq, src, src_mask=None):
        """Run the decoder stack on a target sequence.

        The target indices are embedded, positionally encoded and passed
        through dropout, then fed through every layer of ``layer_stack``
        followed by the final layer norm.

        Args:
            trg_seq (Tensor): Target character indices; the second dimension
                is the sequence length.
            src (Tensor): Encoder output handed to each decoder layer.
            src_mask (Tensor, optional): Mask forwarded to each layer as
                ``dec_enc_attn_mask``.

        Returns:
            Tensor: Output of the last decoder layer after layer norm.
        """
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)

        # A position may be attended to only if it is not padding AND it
        # does not lie in the future of the querying position.
        pad_mask = self.get_pad_mask(trg_seq, pad_idx=self.padding_idx)
        trg_mask = pad_mask & self.get_subsequent_mask(trg_seq)

        output = tgt
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)

        return self.layer_norm(output)

    def _get_mask(self, logit, img_metas):
        """Build a per-sample validity mask from ``valid_ratio``.

        Args:
            logit (Tensor): 3-D tensor of shape :math:`(N, T, *)`; only its
                first two sizes, dtype and device are used.
            img_metas (list[dict] | None): One meta dict per sample. A
                missing ``valid_ratio`` key is treated as ``1.0``.

        Returns:
            Tensor | None: ``None`` if ``img_metas`` is ``None``. Otherwise
            a tensor of shape :math:`(N, T)` (same dtype/device as
            ``logit``) in which, for sample ``i``, the first
            ``min(T, ceil(T * valid_ratio_i))`` entries are 1 and the rest
            are 0.
        """
        N, T, _ = logit.size()
        if img_metas is None:
            return None

        mask = logit.new_zeros((N, T))
        for i, img_meta in enumerate(img_metas):
            valid_ratio = img_meta.get('valid_ratio', 1.0)
            valid_width = min(T, math.ceil(T * valid_ratio))
            mask[i, :valid_width] = 1
        return mask

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        r"""
        Args:
            feat (None): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``.
            targets_dict (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            img_metas (list[dict] | None): Meta information of the input
                images, one dict per image. Preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C - 1)`; the
            padding class is not predicted.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        targets = targets_dict['padded_targets'].to(out_enc.device)
        attn_output = self._attention(targets, out_enc, src_mask=src_mask)
        return self.classifier(attn_output)

    def forward_test(self, feat, out_enc, img_metas):
        r"""Greedy autoregressive decoding.

        Args:
            feat (None): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``.
            img_metas (list[dict] | None): Meta information of the input
                images, one dict per image. Preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: Softmax class probabilities of shape
            :math:`(N, L, C - 1)` where :math:`L` is ``max_seq_len``.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        N = out_enc.size(0)

        # Decoding buffer: slot 0 holds the start token, every other slot
        # starts as padding and is overwritten as characters are predicted.
        init_target_seq = torch.full((N, self.max_seq_len + 1),
                                     self.padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        # bsz * seq_len
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        for step in range(0, self.max_seq_len):
            # NOTE: the whole buffer is re-decoded at every step; only the
            # output at position ``step`` is consumed.
            decoder_output = self._attention(
                init_target_seq, out_enc, src_mask=src_mask)
            # bsz * seq_len * C
            step_result = F.softmax(
                self.classifier(decoder_output[:, step, :]), dim=-1)
            # bsz * num_classes
            outputs.append(step_result)
            # Greedy choice: feed the most probable class back as the next
            # input token.
            _, step_max_index = torch.max(step_result, dim=-1)
            init_target_seq[:, step + 1] = step_max_index

        return torch.stack(outputs, dim=1)