# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule

from mmocr.models.builder import ENCODERS
from mmocr.models.ner.utils.bert import BertModel


@ENCODERS.register_module()
class BertEncoder(BaseModule):
"""Bert encoder
Args:
num_hidden_layers (int): The number of hidden layers.
initializer_range (float):
vocab_size (int): Number of words supported.
hidden_size (int): Hidden size.
max_position_embeddings (int): Max positions embedding size.
type_vocab_size (int): The size of type_vocab.
layer_norm_eps (float): Epsilon of layer norm.
hidden_dropout_prob (float): The dropout probability of hidden layer.
output_attentions (bool): Whether use the attentions in output.
output_hidden_states (bool): Whether use the hidden_states in output.
num_attention_heads (int): The number of attention heads.
attention_probs_dropout_prob (float): The dropout probability
of attention.
intermediate_size (int): The size of intermediate layer.
hidden_act_cfg (dict): Hidden layer activation.
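
    Example:
        A minimal construction sketch, assuming the encoder is built from a
        config dict through the ``ENCODERS`` registry (it can also be
        instantiated directly with the keyword arguments above).

        >>> from mmocr.models.builder import ENCODERS
        >>> encoder = ENCODERS.build(dict(type='BertEncoder'))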
"""
def __init__(self,
num_hidden_layers=12,
initializer_range=0.02,
vocab_size=21128,
hidden_size=768,
max_position_embeddings=128,
type_vocab_size=2,
layer_norm_eps=1e-12,
hidden_dropout_prob=0.1,
output_attentions=False,
output_hidden_states=False,
num_attention_heads=12,
attention_probs_dropout_prob=0.1,
intermediate_size=3072,
hidden_act_cfg=dict(type='GeluNew'),
init_cfg=[
dict(type='Xavier', layer='Conv2d'),
dict(type='Uniform', layer='BatchNorm2d')
]):
super().__init__(init_cfg=init_cfg)
self.bert = BertModel(
num_hidden_layers=num_hidden_layers,
initializer_range=initializer_range,
vocab_size=vocab_size,
hidden_size=hidden_size,
max_position_embeddings=max_position_embeddings,
type_vocab_size=type_vocab_size,
layer_norm_eps=layer_norm_eps,
hidden_dropout_prob=hidden_dropout_prob,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
num_attention_heads=num_attention_heads,
attention_probs_dropout_prob=attention_probs_dropout_prob,
intermediate_size=intermediate_size,
hidden_act_cfg=hidden_act_cfg)

    def forward(self, results):
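        """Run BERT on the tokenized inputs packed in ``results``.

        Example:
            An illustrative sketch of the expected input dict; the small
            config values and dummy tensor shapes below are assumptions
            chosen for brevity, not settings used by the NER pipeline.

            >>> import torch
            >>> encoder = BertEncoder(
            ...     num_hidden_layers=2,
            ...     hidden_size=32,
            ...     num_attention_heads=2,
            ...     intermediate_size=64,
            ...     vocab_size=100)
            >>> results = dict(
            ...     input_ids=torch.randint(0, 100, (1, 8)),
            ...     attention_masks=torch.ones(1, 8, dtype=torch.long),
            ...     token_type_ids=torch.zeros(1, 8, dtype=torch.long))
            >>> outputs = encoder(results)
        """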
        device = next(self.bert.parameters()).device
        input_ids = results['input_ids'].to(device)
        attention_masks = results['attention_masks'].to(device)
        token_type_ids = results['token_type_ids'].to(device)
        outputs = self.bert(
            input_ids=input_ids,
            attention_masks=attention_masks,
            token_type_ids=token_type_ids)
        return outputs