unography committed
Commit 5eb7d19
1 Parent(s): 555a05e

Upload 2 files

configuration_llama_lm_feats.py ADDED
@@ -0,0 +1,18 @@
+ from transformers import LlamaConfig
+
+
+ class LlamaWithFeatsEncoderConfig(LlamaConfig):
+     model_type = "llama_with_feats_encoder"
+
+     def __init__(self, feats_hidden_size=8, **kwargs):
+         super().__init__(**kwargs)
+         self.feats_hidden_size = feats_hidden_size
+
+     def to_dict(self):
+         """
+         Serializes this instance to a Python dictionary.
+         """
+         output = super().to_dict()
+         output["model_type"] = self.model_type
+         output["feats_hidden_size"] = self.feats_hidden_size
+         return output
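For reference, a minimal sketch (not part of this commit) of how the new config can be exercised once configuration_llama_lm_feats.py is importable; the value 16 is an arbitrary assumption. The to_dict() override ensures the custom model_type and the extra feats_hidden_size field end up in the serialized config, so they survive save_pretrained / from_pretrained round trips.

from configuration_llama_lm_feats import LlamaWithFeatsEncoderConfig

# Hypothetical feature width; the default defined in the file above is 8.
config = LlamaWithFeatsEncoderConfig(feats_hidden_size=16)

serialized = config.to_dict()
assert serialized["model_type"] == "llama_with_feats_encoder"
assert serialized["feats_hidden_size"] == 16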
model_llama_lm_feats.py ADDED
@@ -0,0 +1,55 @@
+ import torch.nn as nn
+ from transformers import LlamaForCausalLM
+ from transformers.models.llama.modeling_llama import LlamaMLP
+
+ from .configuration_llama_lm_feats import LlamaWithFeatsEncoderConfig
+
+
+ class LlamaFeatsMLP(LlamaMLP):
+     def __init__(self, config):
+         super().__init__(config)
+         self.gate_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)
+
+
+ class LlamaWithFeatsForCausalLM(LlamaForCausalLM):
+     config_class = LlamaWithFeatsEncoderConfig
+
+     def __init__(self, config: LlamaWithFeatsEncoderConfig):
+         super().__init__(config)
+         self.feature_mlp = LlamaFeatsMLP(config)
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         meta_features=None,
+         position_ids=None,
+         past_key_values=None,
+         inputs_embeds=None,
+         labels=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+
+         if inputs_embeds is None:
+             inputs_embeds = self.model.embed_tokens(input_ids)
+
+         if meta_features is not None:
+             feats_embeds = self.feature_mlp(meta_features)
+             inputs_embeds = inputs_embeds + feats_embeds
+
+         return super().forward(
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
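Taken together, the two files add a per-token feature pathway to a standard LLaMA causal LM: meta_features are projected by a LlamaMLP-style block (with its input projections resized to feats_hidden_size) and added to the token embeddings before the usual LlamaForCausalLM forward pass. Below is a minimal usage sketch under stated assumptions, not something shipped in this commit: it assumes both files live in an importable package (the relative import in model_llama_lm_feats.py requires a parent package), and the config sizes are tiny illustrative values.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Assumes the two uploaded files sit in a package named, say, llama_lm_feats/.
from llama_lm_feats.configuration_llama_lm_feats import LlamaWithFeatsEncoderConfig
from llama_lm_feats.model_llama_lm_feats import LlamaWithFeatsForCausalLM

# Optional: register the classes so AutoConfig / AutoModelForCausalLM can resolve
# the "llama_with_feats_encoder" model_type.
AutoConfig.register("llama_with_feats_encoder", LlamaWithFeatsEncoderConfig)
AutoModelForCausalLM.register(LlamaWithFeatsEncoderConfig, LlamaWithFeatsForCausalLM)

# Tiny, arbitrary hyperparameters for illustration only.
config = LlamaWithFeatsEncoderConfig(
    feats_hidden_size=8,
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = LlamaWithFeatsForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
# One feature vector per token, matching feats_hidden_size.
meta_features = torch.randn(1, 16, config.feats_hidden_size)

outputs = model(input_ids=input_ids, meta_features=meta_features, labels=input_ids)
print(outputs.loss)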