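# Page metadata captured with this file (kept as a comment so the module
# stays valid Python): uploaded by tomofi, commit 2366e36
# ("Add application file"), 6.45 kB, scan result "No virus".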
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList
from mmocr.models.builder import DECODERS
from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from .base_decoder import BaseDecoder
@DECODERS.register_module()
class NRTRDecoder(BaseDecoder):
    """Transformer Decoder block with self attention mechanism.

    Args:
        n_layers (int): Number of attention layers.
        d_embedding (int): Language embedding dimension. The target embedding
            is fed directly into layers built for ``d_model``, so this is
            expected to equal ``d_model``.
        n_head (int): Number of parallel attention heads.
        d_k (int): Dimension of the key vector.
        d_v (int): Dimension of the value vector.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
        d_inner (int): Hidden dimension of feedforward layers.
        n_position (int): Length of the positional encoding vector. Must be
            greater than ``max_seq_len`` (test-time decoding uses a target
            buffer of length ``max_seq_len + 1``).
        dropout (float): Dropout rate.
        num_classes (int): Number of output classes :math:`C`.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        padding_idx (int): The index of `<PAD>`.
        init_cfg (dict or list[dict], optional): Initialization configs.
        **kwargs: Extra keyword arguments forwarded unchanged to every
            ``TFDecoderLayer``.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = ModuleList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        # The last class is assumed to be <PAD> and is never predicted.
        pred_num_class = num_classes - 1
        self.classifier = nn.Linear(d_model, pred_num_class)

    @staticmethod
    def get_pad_mask(seq, pad_idx):
        """Return a key mask that is True where ``seq`` is not padding.

        Args:
            seq (Tensor): Token indices of shape :math:`(N, T)`.
            pad_idx (int): The padding index.

        Returns:
            Tensor: Bool mask of shape :math:`(N, 1, T)`.
        """
        return (seq != pad_idx).unsqueeze(-2)

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info.

        Args:
            seq (Tensor): Token indices of shape :math:`(N, T)`; only its
                length and device are used.

        Returns:
            Tensor: Bool lower-triangular mask (diagonal included) of shape
            :math:`(1, T, T)`. Position ``i`` may attend to positions
            ``<= i``.
        """
        len_s = seq.size(1)
        causal = torch.tril(
            torch.ones((len_s, len_s), dtype=torch.bool, device=seq.device))
        return causal.unsqueeze(0)

    def _attention(self, trg_seq, src, src_mask=None):
        """Run the decoder stack over a target sequence.

        Args:
            trg_seq (Tensor): Target token indices of shape :math:`(N, T)`.
            src (Tensor): Encoder output, passed to every decoder layer.
            src_mask (Tensor, optional): Encoder-side mask from
                :meth:`_get_mask`, or None.

        Returns:
            Tensor: Layer-normalized decoder features with last dimension
            ``d_model``.
        """
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)

        # Hide <PAD> keys and, via the causal mask, all future positions.
        trg_mask = self.get_pad_mask(
            trg_seq,
            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)

        output = tgt
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        return self.layer_norm(output)

    def _get_mask(self, logit, img_metas):
        """Build an encoder-side mask from each image's ``valid_ratio``.

        Args:
            logit (Tensor): Encoder output of shape :math:`(N, T, D_m)`.
                Only its size, dtype and device are used.
            img_metas (list[dict] or None): Per-image meta information. A
                missing ``valid_ratio`` key defaults to 1.0.

        Returns:
            Tensor or None: None when ``img_metas`` is None. Otherwise a mask
            of shape :math:`(N, T)` with ``logit``'s dtype and device. In row
            ``i``, the first ``min(T, ceil(T * valid_ratio_i))`` entries are 1
            and the rest are 0.
        """
        if img_metas is None:
            return None

        N, T, _ = logit.size()
        mask = logit.new_zeros((N, T))
        for i, img_meta in enumerate(img_metas):
            valid_ratio = img_meta.get('valid_ratio', 1.0)
            valid_width = min(T, math.ceil(T * valid_ratio))
            mask[i, :valid_width] = 1
        return mask

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        r"""
        Args:
            feat (None): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``.
            targets_dict (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            img_metas (list[dict]): Meta information of the input images,
                one dict per image. Preferably with the key ``valid_ratio``.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C - 1)`; the
            `<PAD>` class is not predicted.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        targets = targets_dict['padded_targets'].to(out_enc.device)
        attn_output = self._attention(targets, out_enc, src_mask=src_mask)
        outputs = self.classifier(attn_output)
        return outputs

    def forward_test(self, feat, out_enc, img_metas):
        r"""Greedily decode up to ``max_seq_len`` characters.

        Args:
            feat (None): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``.
            img_metas (list[dict]): Meta information of the input images,
                one dict per image. Preferably with the key ``valid_ratio``.

        Returns:
            Tensor: Per-step class probabilities. These are softmax outputs,
            not raw logits. Shape :math:`(N, T, C - 1)` where :math:`T` is
            ``max_seq_len``.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        N = out_enc.size(0)

        # Slot 0 holds <SOS>. The remaining slots start as <PAD> and are
        # overwritten one by one with the previous step's prediction.
        init_target_seq = torch.full((N, self.max_seq_len + 1),
                                     self.padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        for step in range(self.max_seq_len):
            # Under the causal mask, position ``step`` only attends to
            # positions ``<= step``. Decoding just the filled prefix
            # therefore yields the same features at ``step`` as decoding the
            # whole padded buffer, without redundant work on unfilled slots.
            prefix = init_target_seq[:, :step + 1]
            decoder_output = self._attention(
                prefix, out_enc, src_mask=src_mask)
            # Features of the newest position: (N, d_model) -> (N, C - 1).
            step_result = F.softmax(
                self.classifier(decoder_output[:, -1, :]), dim=-1)
            outputs.append(step_result)
            init_target_seq[:, step + 1] = step_result.argmax(dim=-1)

        return torch.stack(outputs, dim=1)