minor
README.md
CHANGED
```diff
@@ -4,8 +4,8 @@ language:
 - en
 ---
 
-# Model Card for SuperSparse-Mixtral
-The SuperSparse-Mixtral Large Language Model (LLM) is a sparsified version of Mixtral.
+# Model Card for TurboSparse-Mixtral
+The TurboSparse-Mixtral Large Language Model (LLM) is a sparsified version of Mixtral.
 
 <img src="takeaway.png" alt="avatar" width="300" height="200"/>
 
@@ -13,7 +13,7 @@ The average performance is evaluated using benchmarks from the OpenLLM Leaderboard.
 
 ## Inference
 
-Our code for accelerating SuperSparse-Mixtral is currently being refined. Stay tuned! For now, you can run this model like a dense model.
+Our code for accelerating TurboSparse-Mixtral is currently being refined. Stay tuned! For now, you can run this model like a dense model.
 
 ## Chat-Template
 
@@ -25,7 +25,7 @@ We take ChatML as our chat template:
 
 ## Allow Finetuning
 
-As the predictors for FFN neurons are merged into the model, you can finetune SuperSparse-Mixtral with any framework and algorithm.
+As the predictors for FFN neurons are merged into the model, you can finetune TurboSparse-Mixtral with any framework and algorithm.
 
 ## License
 
```
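For reference, a minimal sketch of running the model as a dense model through `transformers`, using the ChatML chat template mentioned in the README. The Hub repo id below is a placeholder (substitute the actual path of this repository), and it assumes the tokenizer ships the ChatML template:

```python
# Minimal sketch: run TurboSparse-Mixtral as a dense model via transformers.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<org>/TurboSparse-Mixtral"  # placeholder, not confirmed by this card

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # needed so auto_map resolves the TurboSparseMixtral* classes
)

# ChatML-style conversation; assumes the tokenizer provides a chat template.
messages = [{"role": "user", "content": "Explain mixture-of-experts in one sentence."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```

Because the FFN-neuron predictors are merged into the checkpoint, the same loading path also serves as a starting point for finetuning with standard frameworks.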
config.json
CHANGED
```diff
@@ -3,9 +3,9 @@
     "TurboSparseMixtralForCausalLM"
   ],
   "auto_map": {
-    "AutoConfig": "configuration_supersparsemixtral.SuperSparseMixtralConfig",
-    "AutoModel": "modeling_supersparsemixtral.SuperSparseMixtralForCausalLM",
-    "AutoModelForCausalLM": "modeling_supersparsemixtral.SuperSparseMixtralForCausalLM"
+    "AutoConfig": "configuration_turbosparsemixtral.TurboSparseMixtralConfig",
+    "AutoModel": "modeling_turbosparsemixtral.TurboSparseMixtralForCausalLM",
+    "AutoModelForCausalLM": "modeling_turbosparsemixtral.TurboSparseMixtralForCausalLM"
   },
   "attention_dropout": 0.0,
   "bos_token_id": 1,
@@ -15,7 +15,7 @@
   "initializer_range": 0.02,
   "intermediate_size": 14336,
   "max_position_embeddings": 32768,
-  "model_type": "supersparsemixtral",
+  "model_type": "turbosparsemixtral",
   "num_attention_heads": 32,
   "num_experts_per_tok": 2,
   "num_hidden_layers": 32,
```
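These `auto_map` entries are what let the `Auto*` classes resolve to the custom code shipped with the repo when `trust_remote_code=True` is passed. A quick check, with a placeholder repo id:

```python
# Sketch: with trust_remote_code=True, the auto_map above routes the Auto classes
# to the custom modules in this repository. Repo id is a placeholder.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("<org>/TurboSparse-Mixtral", trust_remote_code=True)
print(type(config).__name__)   # expected: TurboSparseMixtralConfig
print(config.model_type)       # expected: turbosparsemixtral
print(config.num_experts_per_tok, config.num_hidden_layers)  # 2, 32 per this config.json
```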
configuration_supersparsemixtral.py → configuration_turbosparsemixtral.py
RENAMED
````diff
@@ -22,7 +22,7 @@ from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
-class SuperSparseMixtralConfig(PretrainedConfig):
+class TurboSparseMixtralConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an
     Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -106,7 +106,7 @@ class SuperSparseMixtralConfig(PretrainedConfig):
     >>> configuration = model.config
     ```"""
 
-    model_type = "supersparsemixtral"
+    model_type = "turbosparsemixtral"
     keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
````
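If you prefer not to rely on `trust_remote_code`, the two renamed files can be vendored locally and registered with the Auto machinery instead. A sketch, assuming they are placed in a hypothetical local package directory `turbosparse/` with an `__init__.py` (a package is needed because the modeling file uses a relative import):

```python
# Sketch: register the vendored classes so AutoModelForCausalLM recognises
# model_type "turbosparsemixtral" without executing remote code.
from transformers import AutoConfig, AutoModelForCausalLM

# "turbosparse" is a hypothetical local package holding the two repo files.
from turbosparse.configuration_turbosparsemixtral import TurboSparseMixtralConfig
from turbosparse.modeling_turbosparsemixtral import TurboSparseMixtralForCausalLM

AutoConfig.register("turbosparsemixtral", TurboSparseMixtralConfig)
AutoModelForCausalLM.register(TurboSparseMixtralConfig, TurboSparseMixtralForCausalLM)
```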
modeling_supersparsemixtral.py → modeling_turbosparsemixtral.py
RENAMED
```diff
@@ -54,7 +54,7 @@ from transformers.utils import (
     replace_return_docstrings,
     is_torch_fx_available,
 )
-from .configuration_supersparsemixtral import SuperSparseMixtralConfig
+from .configuration_turbosparsemixtral import TurboSparseMixtralConfig
 @dataclass
 class AttentionMaskConverter:
     """
@@ -634,7 +634,7 @@ def _get_unpad_data(attention_mask):
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
-class SuperSparseMixtralRMSNorm(nn.Module):
+class TurboSparseMixtralRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
         MixtralRMSNorm is equivalent to T5LayerNorm
@@ -653,7 +653,7 @@ class SuperSparseMixtralRMSNorm(nn.Module):
 
 # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
 # TODO @longjie no longer copied from Mistral after static cache
-class SuperSparseMixtralRotaryEmbedding(nn.Module):
+class TurboSparseMixtralRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
@@ -742,13 +742,13 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 # copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
 # TODO @longjie no longer copied from Mistral after static cache
-class SuperSparseMixtralAttention(nn.Module):
+class TurboSparseMixtralAttention(nn.Module):
     """
     Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
     and "Generating Long Sequences with Sparse Transformers".
     """
 
-    def __init__(self, config: SuperSparseMixtralConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: TurboSparseMixtralConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -779,7 +779,7 @@ class SuperSparseMixtralAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
-        self.rotary_emb = SuperSparseMixtralRotaryEmbedding(
+        self.rotary_emb = TurboSparseMixtralRotaryEmbedding(
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
@@ -867,7 +867,7 @@ class SuperSparseMixtralAttention(nn.Module):
 
 # copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
 # TODO @longjie no longer copied from Mistral after static cache
-class SuperSparseMixtralFlashAttention2(SuperSparseMixtralAttention):
+class TurboSparseMixtralFlashAttention2(TurboSparseMixtralAttention):
     """
     Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -1154,7 +1154,7 @@ class SuperSparseMixtralFlashAttention2(SuperSparseMixtralAttention):
 
 # copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
 # TODO @longjie no longer copied from Mistral after static cache
-class SuperSparseMixtralSdpaAttention(SuperSparseMixtralAttention):
+class TurboSparseMixtralSdpaAttention(TurboSparseMixtralAttention):
     """
     Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
     `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
@@ -1246,9 +1246,9 @@ class SuperSparseMixtralSdpaAttention(SuperSparseMixtralAttention):
 
 
 MIXTRAL_ATTENTION_CLASSES = {
-    "eager": SuperSparseMixtralAttention,
-    "flash_attention_2": SuperSparseMixtralFlashAttention2,
-    "sdpa": SuperSparseMixtralSdpaAttention,
+    "eager": TurboSparseMixtralAttention,
+    "flash_attention_2": TurboSparseMixtralFlashAttention2,
+    "sdpa": TurboSparseMixtralSdpaAttention,
 }
 
 class MLP(nn.Module):
```
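The `MIXTRAL_ATTENTION_CLASSES` mapping above is keyed by `config._attn_implementation`, so the attention backend can be chosen at load time. A sketch with a placeholder repo id (`flash_attention_2` additionally requires the `flash-attn` package):

```python
# Sketch: config._attn_implementation picks the entry in MIXTRAL_ATTENTION_CLASSES,
# so the backend is selectable when loading. Repo id is a placeholder.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "<org>/TurboSparse-Mixtral",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
    trust_remote_code=True,
)
print(type(model.model.layers[0].self_attn).__name__)  # expected: TurboSparseMixtralSdpaAttention
```

The remaining hunks apply the same rename to the MoE block, decoder layer, pre-trained base class, and task heads: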
```diff
@@ -1264,8 +1264,8 @@ class MLP(nn.Module):
         x = self.fc2(x)
         x = x.sigmoid()
         return x
-class SuperSparseMixtralBlockSparseTop2MLP(nn.Module):
-    def __init__(self, config: SuperSparseMixtralConfig, layer_id):
+class TurboSparseMixtralBlockSparseTop2MLP(nn.Module):
+    def __init__(self, config: TurboSparseMixtralConfig, layer_id):
         super().__init__()
         self.ffn_dim = config.intermediate_size
         self.hidden_dim = config.hidden_size
@@ -1288,7 +1288,7 @@ class SuperSparseMixtralBlockSparseTop2MLP(nn.Module):
         return current_hidden_states
 
 
-class SuperSparseMixtralSparseMoeBlock(nn.Module):
+class TurboSparseMixtralSparseMoeBlock(nn.Module):
     """
     This implementation is
     strictly equivalent to standard MoE with full capacity (no
@@ -1310,7 +1310,7 @@ class SuperSparseMixtralSparseMoeBlock(nn.Module):
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
-        self.experts = nn.ModuleList([SuperSparseMixtralBlockSparseTop2MLP(config, layer_id) for _ in range(self.num_experts)])
+        self.experts = nn.ModuleList([TurboSparseMixtralBlockSparseTop2MLP(config, layer_id) for _ in range(self.num_experts)])
 
         # Jitter parameters
         self.jitter_noise = config.router_jitter_noise
@@ -1356,16 +1356,16 @@ class SuperSparseMixtralSparseMoeBlock(nn.Module):
         return final_hidden_states, router_logits
 
 
-class SuperSparseMixtralDecoderLayer(nn.Module):
-    def __init__(self, config: SuperSparseMixtralConfig, layer_idx: int):
+class TurboSparseMixtralDecoderLayer(nn.Module):
+    def __init__(self, config: TurboSparseMixtralConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
         self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
-        self.block_sparse_moe = SuperSparseMixtralSparseMoeBlock(config, layer_idx)
-        self.input_layernorm = SuperSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = SuperSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.block_sparse_moe = TurboSparseMixtralSparseMoeBlock(config, layer_idx)
+        self.input_layernorm = TurboSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = TurboSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -1451,11 +1451,11 @@ MIXTRAL_START_DOCSTRING = r"""
     MIXTRAL_START_DOCSTRING,
 )
 # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
-class SuperSparseMixtralPreTrainedModel(PreTrainedModel):
-    config_class = SuperSparseMixtralConfig
+class TurboSparseMixtralPreTrainedModel(PreTrainedModel):
+    config_class = TurboSparseMixtralConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["SuperSparseMixtralDecoderLayer"]
+    _no_split_modules = ["TurboSparseMixtralDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
@@ -1546,7 +1546,7 @@ MIXTRAL_INPUTS_DOCSTRING = r"""
 )
 # copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
 # TODO @longjie no longer copied from Mistral after static cache
-class SuperSparseMixtralModel(SuperSparseMixtralPreTrainedModel):
+class TurboSparseMixtralModel(TurboSparseMixtralPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
 
@@ -1554,17 +1554,17 @@ class SuperSparseMixtralModel(SuperSparseMixtralPreTrainedModel):
     config: MixtralConfig
     """
 
-    def __init__(self, config: SuperSparseMixtralConfig):
+    def __init__(self, config: TurboSparseMixtralConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
-            [SuperSparseMixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [TurboSparseMixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self._attn_implementation = config._attn_implementation
-        self.norm = SuperSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = TurboSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
@@ -1741,12 +1741,12 @@ class SuperSparseMixtralModel(SuperSparseMixtralPreTrainedModel):
         )
 
 
-class SuperSparseMixtralForCausalLM(SuperSparseMixtralPreTrainedModel):
+class TurboSparseMixtralForCausalLM(TurboSparseMixtralPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = SuperSparseMixtralModel(config)
+        self.model = TurboSparseMixtralModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.router_aux_loss_coef = config.router_aux_loss_coef
@@ -1974,11 +1974,11 @@ class SuperSparseMixtralForCausalLM(SuperSparseMixtralPreTrainedModel):
     MIXTRAL_START_DOCSTRING,
 )
 # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
-class SuperSparseMixtralForSequenceClassification(SuperSparseMixtralPreTrainedModel):
+class TurboSparseMixtralForSequenceClassification(TurboSparseMixtralPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = SuperSparseMixtralModel(config)
+        self.model = TurboSparseMixtralModel(config)
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
         # Initialize weights and apply final processing
@@ -2090,11 +2090,11 @@ class SuperSparseMixtralForSequenceClassification(SuperSparseMixtralPreTrainedModel):
     MIXTRAL_START_DOCSTRING,
 )
 # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
-class SuperSparseMixtralForTokenClassification(SuperSparseMixtralPreTrainedModel):
+class TurboSparseMixtralForTokenClassification(TurboSparseMixtralPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = SuperSparseMixtralModel(config)
+        self.model = TurboSparseMixtralModel(config)
         if getattr(config, "classifier_dropout", None) is not None:
             classifier_dropout = config.classifier_dropout
         elif getattr(config, "hidden_dropout", None) is not None:
```
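The `TurboSparseMixtralSparseMoeBlock` docstring notes it is strictly equivalent to standard MoE with full capacity. For intuition, here is an illustrative, standalone sketch of the standard Mixtral-style top-2 routing such a block follows. It is not the repository's actual forward pass, and the dimensions used (hidden size 4096, 8 experts) are assumed Mixtral-8x7B values rather than taken from the config excerpt above:

```python
# Illustrative sketch of standard Mixtral-style top-2 routing.
# Not the repository's code; dimensions below are assumed, not from this repo's config.
import torch
import torch.nn.functional as F

def route_top2(hidden_states: torch.Tensor, gate: torch.nn.Linear, top_k: int = 2):
    # hidden_states: (num_tokens, hidden_dim); gate maps hidden_dim -> num_experts
    router_logits = gate(hidden_states)                 # (num_tokens, num_experts)
    routing_weights = F.softmax(router_logits, dim=-1)  # probability over experts
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize over the chosen experts
    return routing_weights, selected_experts, router_logits

# Toy usage: 8 experts, top-2 per token (num_experts_per_tok = 2 as in the config above).
gate = torch.nn.Linear(4096, 8, bias=False)
weights, experts, logits = route_top2(torch.randn(5, 4096), gate)
print(experts.shape)          # torch.Size([5, 2]): two experts chosen per token
print(weights.sum(dim=-1))    # ~1.0 per token after renormalization
```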