from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger, keyed to this module's name (transformers convention).
logger = logging.get_logger(__name__)
class RITAConfig(PretrainedConfig):
    """Configuration class for RITA models.

    Stores the hyperparameters needed to instantiate a RITA model:
    vocabulary size, hidden width, depth, attention heads, context
    length, dropout, and the feed-forward expansion ratio.

    Args:
        vocab_size (int, defaults to 128): Size of the token vocabulary.
        d_model (int, defaults to 768): Hidden dimension of the model.
        num_layers (int, defaults to 12): Number of transformer layers.
        max_seq_len (int, defaults to 1024): Maximum sequence length.
        num_heads (int, defaults to 12): Number of attention heads.
        dropout (float, defaults to 0.0): Dropout probability.
        ff_ratio (int, defaults to 4): Feed-forward width multiplier;
            the feed-forward dimension is ``d_model * ff_ratio``.
        bos_token_id (int, defaults to 50256): Beginning-of-sequence
            token id.  # TODO: confirm — 50256 is the GPT-2 value.
        eos_token_id (int, defaults to 50256): End-of-sequence token
            id.  # TODO: confirm — 50256 is the GPT-2 value.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    # NOTE(review): "codegen" looks like a leftover from a CodeGen config —
    # confirm this should not be "rita" before registering on the Hub.
    model_type = "codegen"

    def __init__(
        self,
        vocab_size=128,
        d_model=768,
        num_layers=12,
        max_seq_len=1024,
        num_heads=12,
        dropout=0.,
        ff_ratio=4,
        bos_token_id=50256,  # TODO
        eos_token_id=50256,  # TODO
        **kwargs,
    ):
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        # Derived rather than stored: feed-forward width is a fixed ratio
        # of the hidden width.
        self.d_feedforward = d_model * ff_ratio
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        # Bug fix: the original line ended with a stray comma, which stored
        # bos_token_id as the 1-tuple (50256,) instead of an int.
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
|