# codify_3b_multi / configuration_codify.py
# Uploaded by smallcloudteam in commit 5cc155f ("Upload config"); original file size 5.45 kB.
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
from packaging import version
from transformers import is_torch_available
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, TensorType
from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
from transformers.utils import logging
# Module-level logger obtained from transformers' logging utilities, named after this module.
logger = logging.get_logger(__name__)
# Hub repo id -> direct download URL of that checkpoint's config.json.
# `/resolve/<revision>/<file>` returns the raw file; `/blob/...` would return the
# HTML file-viewer page instead of JSON.
CODIFY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "smallcloudai/codify_medium_multi": "https://huggingface.co/smallcloudai/codify_medium_multi/resolve/main/config.json",
    "smallcloudai/codify_3b_multi": "https://huggingface.co/smallcloudai/codify_3b_multi/resolve/main/config.json",
}
class CodifyConfig(PretrainedConfig):
    """Configuration for Codify models.

    The number of layers, attention heads and the hidden size are stored under
    the short keys ``L``, ``attn_heads`` and ``E``; ``attribute_map`` aliases the
    standard ``num_hidden_layers`` / ``num_attention_heads`` / ``hidden_size``
    names to them. Those three values have no defaults in ``__init__`` and are
    presumably supplied through ``**kwargs`` (e.g. from ``config.json``) —
    TODO confirm.

    Args:
        vocab_size: Vocabulary size.
        layer_norm_epsilon: Epsilon stored for layer normalization.
        initializer_range: Stored initializer range.
        use_cache: Whether the model should return/use cached past state.
        bos_token_id: Beginning-of-sequence token id.
        eos_token_id: End-of-sequence token id.
        mlp_mult: Stored MLP width multiplier.
        tie_word_embeddings: Forwarded to ``PretrainedConfig``.
        **kwargs: Everything else is forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "codify"
    # Output keys transformers should skip in model output dicts at inference.
    keys_to_ignore_at_inference = ["past_key_values"]
    # Standard transformers attribute name -> key actually stored on this config.
    attribute_map = {
        "num_hidden_layers": "L",
        "num_attention_heads": "attn_heads",
        "hidden_size": "E",
    }

    def __init__(
        self,
        vocab_size=51305,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        mlp_mult=4,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Model-specific fields, assigned in a fixed order before the base
        # initializer runs (the base may assign some of the same names).
        own_fields = {
            "vocab_size": vocab_size,
            "mlp_mult": mlp_mult,
            "layer_norm_epsilon": layer_norm_epsilon,
            "initializer_range": initializer_range,
            "use_cache": use_cache,
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
        }
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class CodifyOnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for Codify models.

    The past-key/value handling mirrors the BLOOM ONNX config: cached tensors
    use a fused ``batch * num_attention_heads`` leading dimension, and keys are
    laid out ``(.., head_dim, past_len)`` while values are ``(.., past_len, head_dim)``.
    """

    torch_onnx_minimum_version = version.parse("1.12")

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: Optional[List[PatchingSpec]] = None,
        use_past: bool = False,
    ):
        """Create the ONNX config.

        Args:
            config: The model configuration being exported.
            task: Export task name, forwarded to ``OnnxConfigWithPast``.
            patching_specs: Optional patches applied during export.
            use_past: Whether to export with ``past_key_values`` inputs.
        """
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        # Make sure the config carries a pad_token_id, falling back to 0 when it
        # is missing or falsy. NOTE(review): presumably needed by the exporter
        # for padded dummy batches — confirm; an explicit pad id may be better.
        if not getattr(self._config, "pad_token_id", None):
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Dynamic-axis spec of the ONNX graph inputs, in forward() order."""
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            # BLOOM-style layout: keys and values carry the sequence on different
            # axes, hence the inverted values shape.
            # See https://github.com/huggingface/transformers/pull/18344
            self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
            # With a cache, the mask covers past tokens plus the new ones.
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
        return common_inputs

    @property
    def num_layers(self) -> int:
        """Number of transformer layers (``num_hidden_layers``, stored as ``L``)."""
        return self._config.num_hidden_layers

    @property
    def num_attention_heads(self) -> int:
        """Number of attention heads (``num_attention_heads``, stored as ``attn_heads``)."""
        # Go through the public alias: CodifyConfig maps num_attention_heads to
        # `attn_heads` and has no `n_head` attribute (BLOOM's name).
        return self._config.num_attention_heads

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance used when validating exported outputs."""
        return 1e-3

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizer",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs for tracing, ordered as the model's forward() expects.

        Returns:
            An ordered mapping with ``input_ids``, then ``past_key_values``
            (only when ``use_past``), then ``attention_mask``.

        Raises:
            ValueError: if ``use_past`` is set but PyTorch is not installed.
        """
        # Skip OnnxConfigWithPast's own implementation (MRO jump) and use the
        # plain OnnxConfig one; past tensors are built below in this model's layout.
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # Inputs must appear in the order they appear in forward().
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
        attention_mask = common_inputs["attention_mask"]

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            import torch

            batch, seqlen = common_inputs["input_ids"].shape
            # Deliberately use a different length for the past than for the inputs.
            past_key_values_length = seqlen + 2
            num_heads = self.num_attention_heads
            head_dim = self._config.hidden_size // num_heads
            fused_batch = batch * num_heads
            past_key_shape = (fused_batch, head_dim, past_key_values_length)
            past_value_shape = (fused_batch, past_key_values_length, head_dim)
            ordered_inputs["past_key_values"] = [
                (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
            ]
            # Extend the mask so it also covers the cached positions.
            attention_mask = torch.cat(
                [attention_mask, torch.ones(batch, past_key_values_length, dtype=attention_mask.dtype)], dim=1
            )

        ordered_inputs["attention_mask"] = attention_mask
        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        """ONNX opset used by default for export."""
        return 13
from transformers import AutoConfig

# Import-time side effect: register CodifyConfig with AutoConfig under its
# model_type ("codify"), so AutoConfig can resolve that model_type to this class.
AutoConfig.register(CodifyConfig.model_type, CodifyConfig)