Upload 2 files
- configuration_llama_lm_feats.py +18 -0
- model_llama_lm_feats.py +55 -0
configuration_llama_lm_feats.py
ADDED
@@ -0,0 +1,18 @@
+from transformers import LlamaConfig
+
+
+class LlamaWithFeatsEncoderConfig(LlamaConfig):
+    model_type = "llama_with_feats_encoder"
+
+    def __init__(self, feats_hidden_size=8, **kwargs):
+        super().__init__(**kwargs)
+        self.feats_hidden_size = feats_hidden_size
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary.
+        """
+        output = super().to_dict()
+        output["model_type"] = self.model_type
+        output["feats_hidden_size"] = self.feats_hidden_size
+        return output
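
A minimal usage sketch for the config above; the import path and the concrete Llama sizes are illustrative assumptions, not part of this upload:

from configuration_llama_lm_feats import LlamaWithFeatsEncoderConfig  # assumed import path

# The extra field rides on top of the usual LlamaConfig keyword arguments.
config = LlamaWithFeatsEncoderConfig(feats_hidden_size=8, hidden_size=256, intermediate_size=688)

# to_dict() carries both model_type and feats_hidden_size, so the extra field
# survives serialization of the config.
d = config.to_dict()
assert d["model_type"] == "llama_with_feats_encoder"
assert d["feats_hidden_size"] == 8
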
model_llama_lm_feats.py
ADDED
@@ -0,0 +1,55 @@
+import torch.nn as nn
+from transformers import LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaMLP
+
+from .configuration_llama_lm_feats import LlamaWithFeatsEncoderConfig
+
+
+class LlamaFeatsMLP(LlamaMLP):
+    def __init__(self, config):
+        super().__init__(config)
+        self.gate_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)
+
+
+class LlamaWithFeatsForCausalLM(LlamaForCausalLM):
+    config_class = LlamaWithFeatsEncoderConfig
+
+    def __init__(self, config: LlamaWithFeatsEncoderConfig):
+        super().__init__(config)
+        self.feature_mlp = LlamaFeatsMLP(config)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        meta_features=None,
+        position_ids=None,
+        past_key_values=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+
+        if inputs_embeds is None:
+            inputs_embeds = self.model.embed_tokens(input_ids)
+
+        if meta_features is not None:
+            feats_embeds = self.feature_mlp(meta_features)
+            inputs_embeds = inputs_embeds + feats_embeds
+
+        return super().forward(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
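
A hedged end-to-end sketch of how the two uploaded files might be used together. The package name, the Auto* registration calls, the tensor shapes, and the tiny config sizes below are illustrative assumptions, not part of this commit; in particular, meta_features is assumed to be one feature vector per token so it broadcasts against the token embeddings.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Assumes the two uploaded files sit in an importable package, e.g. llama_lm_feats/.
from llama_lm_feats.configuration_llama_lm_feats import LlamaWithFeatsEncoderConfig
from llama_lm_feats.model_llama_lm_feats import LlamaWithFeatsForCausalLM

# Optional: register the custom classes so Auto* can resolve "llama_with_feats_encoder".
AutoConfig.register("llama_with_feats_encoder", LlamaWithFeatsEncoderConfig)
AutoModelForCausalLM.register(LlamaWithFeatsEncoderConfig, LlamaWithFeatsForCausalLM)

# Tiny randomly initialized model, just to show shapes flowing through.
config = LlamaWithFeatsEncoderConfig(
    feats_hidden_size=8,
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=688,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = LlamaWithFeatsForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
# One feature vector per token; LlamaFeatsMLP projects it from feats_hidden_size to
# hidden_size before forward() adds it to the token embeddings.
meta_features = torch.randn(1, 16, config.feats_hidden_size)

with torch.no_grad():
    out = model(input_ids=input_ids, meta_features=meta_features)
print(out.logits.shape)  # torch.Size([1, 16, 1000])
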