|
from transformers import PretrainedConfig |
|
|
|
class COMET19_CN_Config(PretrainedConfig):
    """Configuration object for the COMET19_CN model.

    Every named constructor argument is stored verbatim as an instance
    attribute with the same name. Any remaining keyword arguments are
    forwarded unchanged to ``PretrainedConfig.__init__``, which runs after
    the model-specific attributes have been set.

    NOTE(review): the argument meanings below are inferred from their names
    only and have not been confirmed against the model implementation:
        model: architecture identifier (default ``"transformer"``).
        nL, nH, hSize: presumably layer count, head count, hidden size.
        edpt, adpt, rdpt, odpt: presumably dropout rates for different
            sub-modules (embedding / attention / residual / output?).
        pt, afn, init: presumably pre-training source, activation function
            name and initialisation scheme.
        vSize, n_vocab: vocabulary sizes. These are stored independently even
            though both default to 40545.
        n_ctx: context length (default 31).
        return_acts, return_probs: output-selection flags.
    """

    def __init__(
        self,
        model: str = "transformer",
        nL: int = 12,
        nH: int = 12,
        hSize: int = 768,
        edpt: float = 0.1,
        adpt: float = 0.1,
        rdpt: float = 0.1,
        odpt: float = 0.1,
        pt: str = "gpt",
        afn: str = "gelu",
        init: str = "pt",
        vSize: int = 40545,
        n_ctx: int = 31,
        n_vocab: int = 40545,
        return_acts: bool = True,
        return_probs: bool = False,
        **kwargs,
    ):
        # (attribute name, value) pairs, listed in the order they are set.
        own_settings = (
            ("model", model),
            ("nL", nL),
            ("nH", nH),
            ("hSize", hSize),
            ("edpt", edpt),
            ("adpt", adpt),
            ("rdpt", rdpt),
            ("odpt", odpt),
            ("pt", pt),
            ("afn", afn),
            ("init", init),
            ("vSize", vSize),
            ("n_ctx", n_ctx),
            ("n_vocab", n_vocab),
            ("return_acts", return_acts),
            ("return_probs", return_probs),
        )
        # setattr uses the same __setattr__ path as a plain `self.x = value`
        # assignment, so this loop is equivalent to one assignment per field.
        for attr_name, attr_value in own_settings:
            setattr(self, attr_name, attr_value)

        # Initialise the base class last, with everything not consumed above.
        super().__init__(**kwargs)
|
|
|
|
|
def parse_net_config(config):
    """Copy the network hyper-parameters from *config* into a plain dict.

    Args:
        config: any object exposing the attributes named below, for
            example a ``COMET19_CN_Config`` instance.

    Returns:
        A new dict that maps each attribute name to
        ``getattr(config, name)``. Keys appear in this fixed order: model,
        nL, nH, hSize, edpt, adpt, rdpt, odpt, pt, afn, init, vSize, n_ctx.
        The config's ``n_vocab``, ``return_acts`` and ``return_probs``
        attributes are not copied.

    Raises:
        AttributeError: if *config* lacks any of the listed attributes.
    """
    wanted = (
        "model",
        "nL",
        "nH",
        "hSize",
        "edpt",
        "adpt",
        "rdpt",
        "odpt",
        "pt",
        "afn",
        "init",
        "vSize",
        "n_ctx",
    )
    return {name: getattr(config, name) for name in wanted}