import sys
import logging

from transformers import T5Config
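# AUTO_MAP tells the transformers Auto API which custom class implements each
# task head when a checkpoint is loaded with trust_remote_code=True: keys are
# Auto classes, values are "module.ClassName" paths shipped with the repo.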
AUTO_MAP = {
    "AutoModel": "modeling_flash_t5.FlashT5EncoderModel",
    "AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
    "AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification",
    "AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
    "AutoModelForSequenceClassification": "custom_heads_flash_t5.FlashT5ForSequenceClassification",
}
class FlashT5Config(T5Config):
    model_type = "flash_t5"

    def __init__(
        self,
        decoder_start_token_id=0,
        pad_token_id=-100,
        use_glu_mlp=False,
        position_encoding_type="t5",
        use_randomized_position_encoding=False,
        label_smoothing=0.0,
        z_loss=None,
        use_flash_attention=None,
        max_sequence_length=1024,
        attention_dropout_rate=0.0,
        alibi_mode="symetric",
        use_triton_layernorm=False,
        use_triton_crossentropy=False,
        use_triton_gated_mlp=False,
        use_gelu_act=True,
        use_full_bias_size=False,
        rotary_emb_fraction=1.0,
        rotary_base=10000,
        rotary_interleaved=False,
        rotary_scale_base=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.decoder_start_token_id = decoder_start_token_id
        self.pad_token_id = pad_token_id
        self.use_glu_mlp = use_glu_mlp
        self.position_encoding_type = position_encoding_type
        self.use_randomized_position_encoding = use_randomized_position_encoding
        self.label_smoothing = label_smoothing
        self.z_loss = z_loss
        self.use_flash_attention = use_flash_attention
        self.max_sequence_length = max_sequence_length
        self.alibi_mode = alibi_mode
        self.attention_dropout_rate = attention_dropout_rate
        self.use_triton_layernorm = use_triton_layernorm
        self.use_triton_crossentropy = use_triton_crossentropy
        self.use_triton_gated_mlp = use_triton_gated_mlp
        self.use_gelu_act = use_gelu_act
        self.use_full_bias_size = use_full_bias_size
        self.rotary_base = rotary_base
        self.rotary_interleaved = rotary_interleaved
        self.rotary_scale_base = rotary_scale_base
        self.rotary_emb_fraction = rotary_emb_fraction
        self.auto_map = AUTO_MAP
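# Minimal usage sketch (values are illustrative only; every other option falls
# back to the defaults above or to T5Config via **kwargs):
#
#   config = FlashT5Config(max_sequence_length=2048, label_smoothing=0.1)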
def str_to_class(classname):
    """Resolve a class name to the class object defined in this module."""
    return getattr(sys.modules[__name__], classname)
# Register the config and model classes with the transformers Auto API so that
# remote-code loading can resolve them by name.
try:
    FlashT5Config.register_for_auto_class()
    for key, value in AUTO_MAP.items():
        str_to_class(value.split(".")[-1]).register_for_auto_class(key)
except Exception:
    logging.warning("AutoRegister isn't available.")
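# Loading through the Auto API once a checkpoint ships these files (a sketch;
# "org/flash-t5-base" is a hypothetical repo id):
#
#   from transformers import AutoModelForSeq2SeqLM
#   model = AutoModelForSeq2SeqLM.from_pretrained(
#       "org/flash-t5-base", trust_remote_code=True
#   )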