tommymarto committed
Commit 35a5cdb (1 parent: 4e2cab1)

added studentbert config and modeling files

Files changed (3)
  1. config.json +9 -5
  2. configuration_mcqbert.py +10 -0
  3. modeling_mcqbert.py +48 -0
config.json CHANGED
@@ -1,22 +1,26 @@
 {
-  "_name_or_path": "tommymarto/LernnaviBERT_mcqbert3_correct_answers_4096",
-  "architectures": [
-    "MCQBert3"
-  ],
+  "_name_or_path": "epfl-ml4ed/MCQStudentBertSum",
+  "auto_map": {
+    "AutoConfig": "configuration_mcqbert.MCQBertConfig",
+    "AutoModel": "modeling_mcqbert.MCQStudentBert"
+  },
   "attention_probs_dropout_prob": 0.1,
   "classifier_dropout": null,
+  "cls_hidden_size": 256,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 768,
   "initializer_range": 0.02,
+  "integration_strategy": "sum",
   "intermediate_size": 3072,
   "layer_norm_eps": 1e-12,
   "max_position_embeddings": 512,
-  "model_type": "bert",
+  "model_type": "mcqbert",
   "num_attention_heads": 12,
   "num_hidden_layers": 12,
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
+  "student_embedding_size": 4096,
   "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "type_vocab_size": 2,
configuration_mcqbert.py ADDED
from transformers import BertConfig


class MCQBertConfig(BertConfig):
    model_type = "mcqbert"

    def __init__(self, integration_strategy=None, student_embedding_size=4096, cls_hidden_size=256, **kwargs):
        super().__init__(**kwargs)
        # None -> plain MCQBert; "cat"/"sum" select how student embeddings are integrated
        self.integration_strategy = integration_strategy
        self.student_embedding_size = student_embedding_size
        self.cls_hidden_size = cls_hidden_size
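
For reference, a quick sketch of how the three custom fields line up with the config.json diff above (the import path assumes the file is on the local Python path; everything else falls back to `BertConfig` defaults):

```python
from configuration_mcqbert import MCQBertConfig

config = MCQBertConfig(
    integration_strategy="sum",   # None (plain MCQBert), "cat", or "sum"
    student_embedding_size=4096,  # dimensionality of the precomputed student embedding
    cls_hidden_size=256,          # hidden width of the classification head
)
assert config.model_type == "mcqbert"
```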
modeling_mcqbert.py ADDED
from transformers import BertModel
import torch

from .configuration_mcqbert import MCQBertConfig


class MCQStudentBert(BertModel):
    config_class = MCQBertConfig

    def __init__(self, config: MCQBertConfig):
        super().__init__(config)

        if config.integration_strategy is not None:
            # projects the precomputed student embedding into the BERT hidden space
            self.student_embedding_layer = torch.nn.Linear(config.student_embedding_size, config.hidden_size)

        # "cat" feeds [CLS] concatenated with the projected student embedding
        # to the classifier, so its input is twice the hidden size
        cls_input_dim_multiplier = 2 if config.integration_strategy == "cat" else 1
        cls_input_dim = self.config.hidden_size * cls_input_dim_multiplier

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(cls_input_dim, config.cls_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(config.cls_hidden_size, 1),
        )

    def forward(self, input_ids, student_embeddings=None):
        if self.config.integration_strategy is None:
            # MCQBert: no integration strategy, so student embeddings are ignored
            # and the [CLS] representation goes straight to the classifier
            output = super().forward(input_ids)
            return self.classifier(output.last_hidden_state[:, 0, :])

        elif self.config.integration_strategy == "cat":
            # MCQStudentBertCat: concatenate the projected student embedding
            # to the [CLS] representation before classifying
            output = super().forward(input_ids)
            output_with_student_embedding = torch.cat(
                (output.last_hidden_state[:, 0, :], self.student_embedding_layer(student_embeddings)),
                dim=1,
            )
            return self.classifier(output_with_student_embedding)

        elif self.config.integration_strategy == "sum":
            # MCQStudentBertSum: add the projected student embedding to every
            # token embedding (unsqueeze broadcasts it over the sequence dimension)
            input_embeddings = self.embeddings(input_ids)
            combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).unsqueeze(1)
            output = super().forward(inputs_embeds=combined_embeddings)
            return self.classifier(output.last_hidden_state[:, 0, :])

        else:
            raise ValueError(f"{self.config.integration_strategy} is not a known integration_strategy")
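
A hedged end-to-end smoke test of the "sum" path, continuing from the two files above. The tokenizer id is an assumption (any BERT-compatible tokenizer works), the model is randomly initialized rather than loaded from the checkpoint, and the random tensor stands in for a real 4096-dim student embedding:

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumption: BERT-compatible tokenizer
config = MCQBertConfig(integration_strategy="sum")
model = MCQStudentBert(config).eval()  # random init; use from_pretrained for the real checkpoint

inputs = tokenizer("Is 7 a prime number?", return_tensors="pt")
student_embeddings = torch.randn(1, config.student_embedding_size)  # stand-in student embedding

with torch.no_grad():
    logit = model(inputs["input_ids"], student_embeddings=student_embeddings)
print(logit.shape)  # torch.Size([1, 1]) -- one correctness logit per example
```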