File size: 2,091 Bytes
06f3e2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from transformers.utils import logging
from transformers.configuration_utils import PretrainedConfig


# Module-level logger, obtained from the transformers logging utilities and
# named after this module.
logger = logging.get_logger(__name__)


class OpenBAConfig(PretrainedConfig):
    """Configuration container for the ``openba`` model type.

    The constructor stores the model hyper-parameters as instance attributes
    and hands the special token ids, ``is_encoder_decoder`` and any extra
    keyword arguments over to ``PretrainedConfig.__init__``.

    NOTE(review): the descriptions below are inferred from the parameter
    names; confirm the exact semantics against the modeling code.

    Args:
        vocab_size: Stored as ``vocab_size`` (default 32128).
        hidden_size: Stored as ``hidden_size`` (default 512).
        kv_channels: Stored as ``kv_channels`` (default 64).
        ffn_hidden_size: Stored as ``ffn_hidden_size`` (default 2048).
        num_layers: Stored as ``num_layers`` (default 12).
        num_decoder_layers: Stored as ``num_decoder_layers``. When ``None``
            (the default) the value of ``num_layers`` is used instead.
        hidden_dropout: Stored as ``hidden_dropout`` (default 0.1).
        attention_dropout: Stored as ``attention_dropout`` (default 0.1).
        num_heads: Stored as ``num_heads`` (default 8).
        is_encoder_decoder: Forwarded to the base class (default ``True``).
        use_cache: Stored as ``use_cache`` (default ``True``).
        initializer_factor: Stored as ``initializer_factor`` (default 1.0).
        pad_token_id: Forwarded to the base class (default 0).
        eos_token_id: Forwarded to the base class (default 1).
        decoder_start_token_id: Forwarded to the base class (default 0).
        add_qkv_bias: Stored as ``add_qkv_bias`` (default ``False``).
        add_ffn_bias: Stored as ``add_ffn_bias`` (default ``False``).
        add_lm_head_bias: Stored as ``add_lm_head_bias`` (default ``False``).
        max_seq_length: Stored as ``max_seq_length`` (default 1024).
        decoder_max_seq_length: Stored as ``decoder_max_seq_length``
            (default 256).
        **kwargs: Passed through untouched to ``PretrainedConfig.__init__``.
    """

    model_type = "openba"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Alias name (key) -> name of the attribute assigned in ``__init__``
    # (value). This table is consumed by ``PretrainedConfig``.
    attribute_map = {
        "hidden_size": "hidden_size",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers"
    }

    def __init__(
        self,
        vocab_size=32128,
        hidden_size=512,
        kv_channels=64,
        ffn_hidden_size=2048,
        num_layers=12,
        num_decoder_layers=None,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        num_heads=8,
        is_encoder_decoder=True,
        use_cache=True,
        initializer_factor=1.0,
        pad_token_id=0,
        eos_token_id=1,
        decoder_start_token_id=0,
        add_qkv_bias=False,
        add_ffn_bias=False,
        add_lm_head_bias=False,
        max_seq_length=1024,
        decoder_max_seq_length=256,
        **kwargs,
    ):
        """Store the hyper-parameters, then initialise the base config."""
        # Without an explicit decoder layer count, mirror ``num_layers``.
        if num_decoder_layers is None:
            num_decoder_layers = num_layers

        # Size-related settings.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.kv_channels = kv_channels
        self.ffn_hidden_size = ffn_hidden_size
        self.num_layers = num_layers
        self.num_decoder_layers = num_decoder_layers

        # Dropout and initialisation settings.
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.initializer_factor = initializer_factor

        # Head count and bias switches.
        self.num_heads = num_heads
        self.add_qkv_bias = add_qkv_bias
        self.add_ffn_bias = add_ffn_bias
        self.add_lm_head_bias = add_lm_head_bias

        # Sequence-length limits and cache flag.
        self.max_seq_length = max_seq_length
        self.decoder_max_seq_length = decoder_max_seq_length
        self.use_cache = use_cache

        # Token ids and the encoder-decoder flag are handled by the base
        # class, along with any keyword arguments not consumed above.
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )