fine
Browse files- CustomBERTModel.py +33 -0
- Untitled.ipynb +0 -0
- __pycache__/metrics.cpython-312.pyc +0 -0
- __pycache__/recalibration.cpython-312.pyc +0 -0
- __pycache__/visualization.cpython-312.pyc +0 -0
- app.py +48 -0
- data_preprocessor.py +170 -0
- hint_fine_tuning.py +382 -0
- main.py +322 -0
- metrics.py +149 -0
- new_fine_tuning/README.md +197 -0
- new_fine_tuning/__pycache__/metrics.cpython-312.pyc +0 -0
- new_fine_tuning/__pycache__/recalibration.cpython-312.pyc +0 -0
- new_fine_tuning/__pycache__/visualization.cpython-312.pyc +0 -0
- new_hint_fine_tuned.py +131 -0
- new_test_saved_finetuned_model.py +613 -0
- plot.png +0 -0
- prepare_pretraining_input_vocab_file.py +0 -0
- ratio_proportion_change3_2223/sch_largest_100-coded/pretraining/vocab.txt +34 -0
- recalibration.py +82 -0
- src/__pycache__/attention.cpython-312.pyc +0 -0
- src/__pycache__/bert.cpython-312.pyc +0 -0
- src/__pycache__/classifier_model.cpython-312.pyc +0 -0
- src/__pycache__/dataset.cpython-312.pyc +0 -0
- src/__pycache__/embedding.cpython-312.pyc +0 -0
- src/__pycache__/seq_model.cpython-312.pyc +0 -0
- src/__pycache__/transformer.cpython-312.pyc +0 -0
- src/__pycache__/transformer_component.cpython-312.pyc +0 -0
- src/__pycache__/vocab.cpython-312.pyc +0 -0
- src/attention.py +21 -1
- src/bert.py +35 -0
- src/classifier_model.py +52 -1
- src/dataset.py +385 -0
- src/pretrainer.py +713 -0
- src/reference_code/bert_reference_code.py +1622 -0
- src/reference_code/evaluate_embeddings.py +136 -0
- src/reference_code/metrics.py +149 -0
- src/reference_code/pretrainer-old.py +696 -0
- src/reference_code/test.py +493 -0
- src/reference_code/utils.py +369 -0
- src/reference_code/visualization.py +78 -0
- src/seq_model.py +15 -0
- src/transformer.py +11 -0
- src/vocab.py +17 -0
- test.py +8 -0
- test.txt +0 -0
- test_hint_fine_tuned.py +45 -0
- test_saved_model.py +234 -0
- visualization.py +78 -0
CustomBERTModel.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from src.bert import BERT
|
4 |
+
|
5 |
+
class CustomBERTModel(nn.Module):
|
6 |
+
def __init__(self, vocab_size, output_dim, pre_trained_model_path):
|
7 |
+
super(CustomBERTModel, self).__init__()
|
8 |
+
hidden_size = 768
|
9 |
+
self.bert = BERT(vocab_size=vocab_size, hidden=hidden_size, n_layers=4, attn_heads=8, dropout=0.1)
|
10 |
+
|
11 |
+
# Load the pre-trained model's state_dict
|
12 |
+
checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
|
13 |
+
if isinstance(checkpoint, dict):
|
14 |
+
self.bert.load_state_dict(checkpoint)
|
15 |
+
else:
|
16 |
+
raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")
|
17 |
+
|
18 |
+
# Fully connected layer with input size 768 (matching BERT hidden size)
|
19 |
+
self.fc = nn.Linear(hidden_size, output_dim)
|
20 |
+
|
21 |
+
def forward(self, sequence, segment_info):
|
22 |
+
sequence = sequence.to(next(self.parameters()).device)
|
23 |
+
segment_info = segment_info.to(sequence.device)
|
24 |
+
|
25 |
+
x = self.bert(sequence, segment_info)
|
26 |
+
print(f"BERT output shape: {x.shape}")
|
27 |
+
|
28 |
+
cls_embeddings = x[:, 0] # Extract CLS token embeddings
|
29 |
+
print(f"CLS Embeddings shape: {cls_embeddings.shape}")
|
30 |
+
|
31 |
+
logits = self.fc(cls_embeddings) # Pass tensor of size (batch_size, 768) to the fully connected layer
|
32 |
+
|
33 |
+
return logits
|
Untitled.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
__pycache__/metrics.cpython-312.pyc
ADDED
Binary file (9.14 kB). View file
|
|
__pycache__/recalibration.cpython-312.pyc
ADDED
Binary file (5.49 kB). View file
|
|
__pycache__/visualization.cpython-312.pyc
ADDED
Binary file (5.27 kB). View file
|
|
app.py
CHANGED
@@ -101,15 +101,48 @@ import shutil
|
|
101 |
import matplotlib.pyplot as plt
|
102 |
from sklearn.metrics import roc_curve, auc
|
103 |
# Define the function to process the input file and model selection
|
|
|
|
|
|
|
104 |
def process_file(file,label, model_name):
|
|
|
105 |
with open(file.name, 'r') as f:
|
106 |
content = f.read()
|
107 |
saved_test_dataset = "train.txt"
|
108 |
saved_test_label = "train_label.txt"
|
|
|
|
|
|
|
|
|
109 |
|
110 |
# Save the uploaded file content to a specified location
|
111 |
shutil.copyfile(file.name, saved_test_dataset)
|
112 |
shutil.copyfile(label.name, saved_test_label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
# For demonstration purposes, we'll just return the content with the selected model name
|
114 |
if(model_name=="FS"):
|
115 |
checkpoint="ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32"
|
@@ -126,6 +159,7 @@ def process_file(file,label, model_name):
|
|
126 |
subprocess.run(["python", "src/test_saved_model.py",
|
127 |
"--finetuned_bert_checkpoint",checkpoint
|
128 |
])
|
|
|
129 |
result = {}
|
130 |
with open("result.txt", 'r') as file:
|
131 |
for line in file:
|
@@ -160,7 +194,11 @@ def process_file(file,label, model_name):
|
|
160 |
return text_output,plot_path
|
161 |
|
162 |
# List of models for the dropdown menu
|
|
|
|
|
|
|
163 |
models = ["FS", "IS", "CORRECTNESS","EFFECTIVENESS"]
|
|
|
164 |
|
165 |
# Create the Gradio interface
|
166 |
with gr.Blocks(css="""
|
@@ -350,15 +388,25 @@ tbody.svelte-18wv37q>tr.svelte-18wv37q:nth-child(odd) {
|
|
350 |
with gr.Row():
|
351 |
file_input = gr.File(label="Upload a test file", file_types=['.txt'], elem_classes="file-box")
|
352 |
label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
|
|
|
|
|
|
|
|
|
|
|
353 |
|
354 |
model_dropdown = gr.Dropdown(choices=models, label="Select Model", elem_classes="dropdown-menu")
|
|
|
355 |
|
356 |
with gr.Row():
|
357 |
output_text = gr.Textbox(label="Output Text")
|
358 |
output_image = gr.Image(label="Output Plot")
|
359 |
|
360 |
btn = gr.Button("Submit")
|
|
|
|
|
|
|
361 |
btn.click(fn=process_file, inputs=[file_input,label_input, model_dropdown], outputs=[output_text,output_image])
|
|
|
362 |
|
363 |
# Launch the app
|
364 |
demo.launch()
|
|
|
101 |
import matplotlib.pyplot as plt
|
102 |
from sklearn.metrics import roc_curve, auc
|
103 |
# Define the function to process the input file and model selection
|
104 |
+
<<<<<<< HEAD
|
105 |
+
def process_file(file,label,info, model_name):
|
106 |
+
=======
|
107 |
def process_file(file,label, model_name):
|
108 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
109 |
with open(file.name, 'r') as f:
|
110 |
content = f.read()
|
111 |
saved_test_dataset = "train.txt"
|
112 |
saved_test_label = "train_label.txt"
|
113 |
+
<<<<<<< HEAD
|
114 |
+
saved_train_info="train_info.txt"
|
115 |
+
=======
|
116 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
117 |
|
118 |
# Save the uploaded file content to a specified location
|
119 |
shutil.copyfile(file.name, saved_test_dataset)
|
120 |
shutil.copyfile(label.name, saved_test_label)
|
121 |
+
<<<<<<< HEAD
|
122 |
+
shutil.copyfile(info.name, saved_train_info)
|
123 |
+
# For demonstration purposes, we'll just return the content with the selected model name
|
124 |
+
# if(model_name=="highGRschool10"):
|
125 |
+
# checkpoint="ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32"
|
126 |
+
# elif(model_name=="lowGRschoolAll"):
|
127 |
+
# checkpoint="ratio_proportion_change3/output/IS/bert_fine_tuned.model.ep14"
|
128 |
+
# elif(model_name=="fullTest"):
|
129 |
+
# checkpoint="ratio_proportion_change3/output/correctness/bert_fine_tuned.model.ep48"
|
130 |
+
# else:
|
131 |
+
# checkpoint=None
|
132 |
+
|
133 |
+
# print(checkpoint)
|
134 |
+
subprocess.run([
|
135 |
+
"python", "new_test_saved_finetuned_model.py",
|
136 |
+
"-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
|
137 |
+
"-finetune_task", model_name,
|
138 |
+
"-test_dataset_path","../../../../train.txt",
|
139 |
+
# "-test_label_path","../../../../train_label.txt",
|
140 |
+
"-finetuned_bert_classifier_checkpoint",
|
141 |
+
"ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
|
142 |
+
"-e",str(1),
|
143 |
+
"-b",str(5)
|
144 |
+
], shell=True)
|
145 |
+
=======
|
146 |
# For demonstration purposes, we'll just return the content with the selected model name
|
147 |
if(model_name=="FS"):
|
148 |
checkpoint="ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32"
|
|
|
159 |
subprocess.run(["python", "src/test_saved_model.py",
|
160 |
"--finetuned_bert_checkpoint",checkpoint
|
161 |
])
|
162 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
163 |
result = {}
|
164 |
with open("result.txt", 'r') as file:
|
165 |
for line in file:
|
|
|
194 |
return text_output,plot_path
|
195 |
|
196 |
# List of models for the dropdown menu
|
197 |
+
<<<<<<< HEAD
|
198 |
+
models = ["highGRschool10", "lowGRschoolAll", "fullTest"]
|
199 |
+
=======
|
200 |
models = ["FS", "IS", "CORRECTNESS","EFFECTIVENESS"]
|
201 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
202 |
|
203 |
# Create the Gradio interface
|
204 |
with gr.Blocks(css="""
|
|
|
388 |
with gr.Row():
|
389 |
file_input = gr.File(label="Upload a test file", file_types=['.txt'], elem_classes="file-box")
|
390 |
label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
|
391 |
+
<<<<<<< HEAD
|
392 |
+
info_input = gr.File(label="Upload test info", file_types=['.txt'], elem_classes="file-box")
|
393 |
+
|
394 |
+
model_dropdown = gr.Dropdown(choices=models, label="Select Finetune Task", elem_classes="dropdown-menu")
|
395 |
+
=======
|
396 |
|
397 |
model_dropdown = gr.Dropdown(choices=models, label="Select Model", elem_classes="dropdown-menu")
|
398 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
399 |
|
400 |
with gr.Row():
|
401 |
output_text = gr.Textbox(label="Output Text")
|
402 |
output_image = gr.Image(label="Output Plot")
|
403 |
|
404 |
btn = gr.Button("Submit")
|
405 |
+
<<<<<<< HEAD
|
406 |
+
btn.click(fn=process_file, inputs=[file_input,label_input,info_input, model_dropdown], outputs=[output_text,output_image])
|
407 |
+
=======
|
408 |
btn.click(fn=process_file, inputs=[file_input,label_input, model_dropdown], outputs=[output_text,output_image])
|
409 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
410 |
|
411 |
# Launch the app
|
412 |
demo.launch()
|
data_preprocessor.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
import sys
|
5 |
+
|
6 |
+
class DataPreprocessor:
|
7 |
+
def __init__(self, input_file_path):
|
8 |
+
self.input_file_path = input_file_path
|
9 |
+
self.unique_students = None
|
10 |
+
self.unique_problems = None
|
11 |
+
self.unique_prob_hierarchy = None
|
12 |
+
self.unique_steps = None
|
13 |
+
self.unique_kcs = None
|
14 |
+
|
15 |
+
def analyze_dataset(self):
|
16 |
+
file_iterator = self.load_file_iterator()
|
17 |
+
|
18 |
+
start_time = time.time()
|
19 |
+
self.unique_students = {"st"}
|
20 |
+
self.unique_problems = {"pr"}
|
21 |
+
self.unique_prob_hierarchy = {"ph"}
|
22 |
+
self.unique_kcs = {"kc"}
|
23 |
+
for chunk_data in file_iterator:
|
24 |
+
for student_id, std_groups in chunk_data.groupby('Anon Student Id'):
|
25 |
+
self.unique_students.update({student_id})
|
26 |
+
prob_hierarchy = std_groups.groupby('Level (Workspace Id)')
|
27 |
+
for hierarchy, hierarchy_groups in prob_hierarchy:
|
28 |
+
self.unique_prob_hierarchy.update({hierarchy})
|
29 |
+
prob_name = hierarchy_groups.groupby('Problem Name')
|
30 |
+
for problem_name, prob_name_groups in prob_name:
|
31 |
+
self.unique_problems.update({problem_name})
|
32 |
+
sub_skills = prob_name_groups['KC Model(MATHia)']
|
33 |
+
for a in sub_skills:
|
34 |
+
if str(a) != "nan":
|
35 |
+
temp = a.split("~~")
|
36 |
+
for kc in temp:
|
37 |
+
self.unique_kcs.update({kc})
|
38 |
+
self.unique_students.remove("st")
|
39 |
+
self.unique_problems.remove("pr")
|
40 |
+
self.unique_prob_hierarchy.remove("ph")
|
41 |
+
self.unique_kcs.remove("kc")
|
42 |
+
end_time = time.time()
|
43 |
+
print("Time Taken to analyze dataset = ", end_time - start_time)
|
44 |
+
print("Length of unique students->", len(self.unique_students))
|
45 |
+
print("Length of unique problems->", len(self.unique_problems))
|
46 |
+
print("Length of unique problem hierarchy->", len(self.unique_prob_hierarchy))
|
47 |
+
print("Length of Unique Knowledge components ->", len(self.unique_kcs))
|
48 |
+
|
49 |
+
def analyze_dataset_by_section(self, workspace_name):
|
50 |
+
file_iterator = self.load_file_iterator()
|
51 |
+
|
52 |
+
start_time = time.time()
|
53 |
+
self.unique_students = {"st"}
|
54 |
+
self.unique_problems = {"pr"}
|
55 |
+
self.unique_prob_hierarchy = {"ph"}
|
56 |
+
self.unique_steps = {"s"}
|
57 |
+
self.unique_kcs = {"kc"}
|
58 |
+
# with open("workspace_info.txt", 'a') as f:
|
59 |
+
# sys.stdout = f
|
60 |
+
for chunk_data in file_iterator:
|
61 |
+
for student_id, std_groups in chunk_data.groupby('Anon Student Id'):
|
62 |
+
prob_hierarchy = std_groups.groupby('Level (Workspace Id)')
|
63 |
+
for hierarchy, hierarchy_groups in prob_hierarchy:
|
64 |
+
if workspace_name == hierarchy:
|
65 |
+
# print("Workspace : ", hierarchy)
|
66 |
+
self.unique_students.update({student_id})
|
67 |
+
self.unique_prob_hierarchy.update({hierarchy})
|
68 |
+
prob_name = hierarchy_groups.groupby('Problem Name')
|
69 |
+
for problem_name, prob_name_groups in prob_name:
|
70 |
+
self.unique_problems.update({problem_name})
|
71 |
+
step_names = prob_name_groups['Step Name']
|
72 |
+
sub_skills = prob_name_groups['KC Model(MATHia)']
|
73 |
+
for step in step_names:
|
74 |
+
if str(step) != "nan":
|
75 |
+
self.unique_steps.update({step})
|
76 |
+
for a in sub_skills:
|
77 |
+
if str(a) != "nan":
|
78 |
+
temp = a.split("~~")
|
79 |
+
for kc in temp:
|
80 |
+
self.unique_kcs.update({kc})
|
81 |
+
self.unique_problems.remove("pr")
|
82 |
+
self.unique_prob_hierarchy.remove("ph")
|
83 |
+
self.unique_steps.remove("s")
|
84 |
+
self.unique_kcs.remove("kc")
|
85 |
+
end_time = time.time()
|
86 |
+
print("Time Taken to analyze dataset = ", end_time - start_time)
|
87 |
+
print("Workspace-> ",workspace_name)
|
88 |
+
print("Length of unique students->", len(self.unique_students))
|
89 |
+
print("Length of unique problems->", len(self.unique_problems))
|
90 |
+
print("Length of unique problem hierarchy->", len(self.unique_prob_hierarchy))
|
91 |
+
print("Length of unique step names ->", len(self.unique_steps))
|
92 |
+
print("Length of unique knowledge components ->", len(self.unique_kcs))
|
93 |
+
# f.close()
|
94 |
+
# sys.stdout = sys.__stdout__
|
95 |
+
|
96 |
+
def analyze_dataset_by_school(self, workspace_name, school_id=None):
|
97 |
+
file_iterator = self.load_file_iterator(sep=",")
|
98 |
+
|
99 |
+
start_time = time.time()
|
100 |
+
self.unique_schools = set()
|
101 |
+
self.unique_class = set()
|
102 |
+
self.unique_students = set()
|
103 |
+
self.unique_problems = set()
|
104 |
+
self.unique_steps = set()
|
105 |
+
self.unique_kcs = set()
|
106 |
+
self.unique_actions = set()
|
107 |
+
self.unique_outcomes = set()
|
108 |
+
self.unique_new_steps_w_action_attempt = set()
|
109 |
+
self.unique_new_steps_w_kcs = set()
|
110 |
+
self.unique_new_steps_w_action_attempt_kcs = set()
|
111 |
+
|
112 |
+
for chunk_data in file_iterator:
|
113 |
+
for school, school_group in chunk_data.groupby('CF (Anon School Id)'):
|
114 |
+
# if school and school == school_id:
|
115 |
+
self.unique_schools.add(school)
|
116 |
+
for class_id, class_group in school_group.groupby('CF (Anon Class Id)'):
|
117 |
+
self.unique_class.add(class_id)
|
118 |
+
for student_id, std_group in class_group.groupby('Anon Student Id'):
|
119 |
+
self.unique_students.add(student_id)
|
120 |
+
for prob, prob_group in std_group.groupby('Problem Name'):
|
121 |
+
self.unique_problems.add(prob)
|
122 |
+
|
123 |
+
step_names = set(prob_group['Step Name'])
|
124 |
+
sub_skills = set(prob_group['KC Model(MATHia)'])
|
125 |
+
actions = set(prob_group['Action'])
|
126 |
+
outcomes = set(prob_group['Outcome'])
|
127 |
+
|
128 |
+
self.unique_steps.update(step_names)
|
129 |
+
self.unique_kcs.update(sub_skills)
|
130 |
+
self.unique_actions.update(actions)
|
131 |
+
self.unique_outcomes.update(outcomes)
|
132 |
+
|
133 |
+
for step in step_names:
|
134 |
+
if pd.isna(step):
|
135 |
+
step_group = prob_group[pd.isna(prob_group['Step Name'])]
|
136 |
+
else:
|
137 |
+
step_group = prob_group[prob_group['Step Name']==step]
|
138 |
+
|
139 |
+
for kc in set(step_group['KC Model(MATHia)']):
|
140 |
+
new_step = f"{step}:{kc}"
|
141 |
+
self.unique_new_steps_w_kcs.add(new_step)
|
142 |
+
|
143 |
+
for action, action_group in step_group.groupby('Action'):
|
144 |
+
for attempt, attempt_group in action_group.groupby('Attempt At Step'):
|
145 |
+
new_step = f"{step}:{action}:{attempt}"
|
146 |
+
self.unique_new_steps_w_action_attempt.add(new_step)
|
147 |
+
|
148 |
+
for kc in set(attempt_group["KC Model(MATHia)"]):
|
149 |
+
new_step = f"{step}:{action}:{attempt}:{kc}"
|
150 |
+
self.unique_new_steps_w_action_attempt_kcs.add(new_step)
|
151 |
+
|
152 |
+
|
153 |
+
end_time = time.time()
|
154 |
+
print("Time Taken to analyze dataset = ", end_time - start_time)
|
155 |
+
print("Workspace-> ",workspace_name)
|
156 |
+
print("Length of unique students->", len(self.unique_students))
|
157 |
+
print("Length of unique problems->", len(self.unique_problems))
|
158 |
+
print("Length of unique classes->", len(self.unique_class))
|
159 |
+
print("Length of unique step names ->", len(self.unique_steps))
|
160 |
+
print("Length of unique knowledge components ->", len(self.unique_kcs))
|
161 |
+
print("Length of unique actions ->", len(self.unique_actions))
|
162 |
+
print("Length of unique outcomes ->", len(self.unique_outcomes))
|
163 |
+
print("Length of unique new step names with actions and attempts ->", len(self.unique_new_steps_w_action_attempt))
|
164 |
+
print("Length of unique new step names with actions, attempts and kcs ->", len(self.unique_new_steps_w_action_attempt_kcs))
|
165 |
+
print("Length of unique new step names with kcs ->", len(self.unique_new_steps_w_kcs))
|
166 |
+
|
167 |
+
def load_file_iterator(self, sep="\t"):
|
168 |
+
chunk_iterator = pd.read_csv(self.input_file_path, sep=sep, header=0, iterator=True, chunksize=1000000)
|
169 |
+
return chunk_iterator
|
170 |
+
|
hint_fine_tuning.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.utils.data import DataLoader, random_split, TensorDataset
|
7 |
+
from src.dataset import TokenizerDataset
|
8 |
+
from src.bert import BERT
|
9 |
+
from src.pretrainer import BERTFineTuneTrainer1
|
10 |
+
from src.vocab import Vocab
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
|
14 |
+
# class CustomBERTModel(nn.Module):
|
15 |
+
# def __init__(self, vocab_size, output_dim, pre_trained_model_path):
|
16 |
+
# super(CustomBERTModel, self).__init__()
|
17 |
+
# hidden_size = 768
|
18 |
+
# self.bert = BERT(vocab_size=vocab_size, hidden=hidden_size, n_layers=12, attn_heads=12, dropout=0.1)
|
19 |
+
# checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
|
20 |
+
# if isinstance(checkpoint, dict):
|
21 |
+
# self.bert.load_state_dict(checkpoint)
|
22 |
+
# elif isinstance(checkpoint, BERT):
|
23 |
+
# self.bert = checkpoint
|
24 |
+
# else:
|
25 |
+
# raise TypeError(f"Expected state_dict or BERT instance, got {type(checkpoint)} instead.")
|
26 |
+
# self.fc = nn.Linear(hidden_size, output_dim)
|
27 |
+
|
28 |
+
# def forward(self, sequence, segment_info):
|
29 |
+
# sequence = sequence.to(next(self.parameters()).device)
|
30 |
+
# segment_info = segment_info.to(sequence.device)
|
31 |
+
|
32 |
+
# if sequence.size(0) == 0 or sequence.size(1) == 0:
|
33 |
+
# raise ValueError("Input sequence tensor has 0 elements. Check data preprocessing.")
|
34 |
+
|
35 |
+
# x = self.bert(sequence, segment_info)
|
36 |
+
# print(f"BERT output shape: {x.shape}")
|
37 |
+
|
38 |
+
# if x.size(0) == 0 or x.size(1) == 0:
|
39 |
+
# raise ValueError("BERT output tensor has 0 elements. Check input dimensions.")
|
40 |
+
|
41 |
+
# cls_embeddings = x[:, 0]
|
42 |
+
# logits = self.fc(cls_embeddings)
|
43 |
+
# return logits
|
44 |
+
|
45 |
+
# class CustomBERTModel(nn.Module):
|
46 |
+
# def __init__(self, vocab_size, output_dim, pre_trained_model_path):
|
47 |
+
# super(CustomBERTModel, self).__init__()
|
48 |
+
# hidden_size = 764 # Ensure this is 768
|
49 |
+
# self.bert = BERT(vocab_size=vocab_size, hidden=hidden_size, n_layers=12, attn_heads=12, dropout=0.1)
|
50 |
+
|
51 |
+
# # Load the pre-trained model's state_dict
|
52 |
+
# checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
|
53 |
+
# if isinstance(checkpoint, dict):
|
54 |
+
# self.bert.load_state_dict(checkpoint)
|
55 |
+
# else:
|
56 |
+
# raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")
|
57 |
+
|
58 |
+
# # Fully connected layer with input size 768
|
59 |
+
# self.fc = nn.Linear(hidden_size, output_dim)
|
60 |
+
|
61 |
+
# def forward(self, sequence, segment_info):
|
62 |
+
# sequence = sequence.to(next(self.parameters()).device)
|
63 |
+
# segment_info = segment_info.to(sequence.device)
|
64 |
+
|
65 |
+
# x = self.bert(sequence, segment_info)
|
66 |
+
# print(f"BERT output shape: {x.shape}") # Should output (batch_size, seq_len, 768)
|
67 |
+
|
68 |
+
# cls_embeddings = x[:, 0] # Extract CLS token embeddings
|
69 |
+
# print(f"CLS Embeddings shape: {cls_embeddings.shape}") # Should output (batch_size, 768)
|
70 |
+
|
71 |
+
# logits = self.fc(cls_embeddings) # Should now pass a tensor of size (batch_size, 768) to `fc`
|
72 |
+
|
73 |
+
# return logits
|
74 |
+
|
75 |
+
|
76 |
+
# for test
|
77 |
+
class CustomBERTModel(nn.Module):
|
78 |
+
def __init__(self, vocab_size, output_dim, pre_trained_model_path):
|
79 |
+
super(CustomBERTModel, self).__init__()
|
80 |
+
self.hidden = 764 # Ensure this is defined correctly
|
81 |
+
self.bert = BERT(vocab_size=vocab_size, hidden=self.hidden, n_layers=12, attn_heads=12, dropout=0.1)
|
82 |
+
|
83 |
+
# Load the pre-trained model's state_dict
|
84 |
+
checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
|
85 |
+
if isinstance(checkpoint, dict):
|
86 |
+
self.bert.load_state_dict(checkpoint)
|
87 |
+
else:
|
88 |
+
raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")
|
89 |
+
|
90 |
+
self.fc = nn.Linear(self.hidden, output_dim)
|
91 |
+
|
92 |
+
def forward(self, sequence, segment_info):
|
93 |
+
x = self.bert(sequence, segment_info)
|
94 |
+
cls_embeddings = x[:, 0] # Extract CLS token embeddings
|
95 |
+
logits = self.fc(cls_embeddings) # Pass to fully connected layer
|
96 |
+
return logits
|
97 |
+
|
98 |
+
def preprocess_labels(label_csv_path):
|
99 |
+
try:
|
100 |
+
labels_df = pd.read_csv(label_csv_path)
|
101 |
+
labels = labels_df['last_hint_class'].values.astype(int)
|
102 |
+
return torch.tensor(labels, dtype=torch.long)
|
103 |
+
except Exception as e:
|
104 |
+
print(f"Error reading dataset file: {e}")
|
105 |
+
return None
|
106 |
+
|
107 |
+
|
108 |
+
def preprocess_data(data_path, vocab, max_length=128):
|
109 |
+
try:
|
110 |
+
with open(data_path, 'r') as f:
|
111 |
+
sequences = f.readlines()
|
112 |
+
except Exception as e:
|
113 |
+
print(f"Error reading data file: {e}")
|
114 |
+
return None, None
|
115 |
+
|
116 |
+
if len(sequences) == 0:
|
117 |
+
raise ValueError(f"No sequences found in data file {data_path}. Check the file content.")
|
118 |
+
|
119 |
+
tokenized_sequences = []
|
120 |
+
|
121 |
+
for sequence in sequences:
|
122 |
+
sequence = sequence.strip()
|
123 |
+
if sequence:
|
124 |
+
encoded = vocab.to_seq(sequence, seq_len=max_length)
|
125 |
+
encoded = encoded[:max_length] + [vocab.vocab.get('[PAD]', 0)] * (max_length - len(encoded))
|
126 |
+
segment_label = [0] * max_length
|
127 |
+
|
128 |
+
tokenized_sequences.append({
|
129 |
+
'input_ids': torch.tensor(encoded),
|
130 |
+
'segment_label': torch.tensor(segment_label)
|
131 |
+
})
|
132 |
+
|
133 |
+
if not tokenized_sequences:
|
134 |
+
raise ValueError("Tokenization resulted in an empty list. Check the sequences and tokenization logic.")
|
135 |
+
|
136 |
+
tokenized_sequences = [t for t in tokenized_sequences if len(t['input_ids']) == max_length]
|
137 |
+
|
138 |
+
if not tokenized_sequences:
|
139 |
+
raise ValueError("All tokenized sequences are of unexpected length. This suggests an issue with the tokenization logic.")
|
140 |
+
|
141 |
+
input_ids = torch.cat([t['input_ids'].unsqueeze(0) for t in tokenized_sequences], dim=0)
|
142 |
+
segment_labels = torch.cat([t['segment_label'].unsqueeze(0) for t in tokenized_sequences], dim=0)
|
143 |
+
|
144 |
+
print(f"Input IDs shape: {input_ids.shape}")
|
145 |
+
print(f"Segment labels shape: {segment_labels.shape}")
|
146 |
+
|
147 |
+
return input_ids, segment_labels
|
148 |
+
|
149 |
+
|
150 |
+
def collate_fn(batch):
|
151 |
+
inputs = []
|
152 |
+
labels = []
|
153 |
+
segment_labels = []
|
154 |
+
|
155 |
+
for item in batch:
|
156 |
+
if item is None:
|
157 |
+
continue
|
158 |
+
|
159 |
+
if isinstance(item, dict):
|
160 |
+
inputs.append(item['input_ids'].unsqueeze(0))
|
161 |
+
labels.append(item['label'].unsqueeze(0))
|
162 |
+
segment_labels.append(item['segment_label'].unsqueeze(0))
|
163 |
+
|
164 |
+
if len(inputs) == 0 or len(segment_labels) == 0:
|
165 |
+
print("Empty batch encountered. Returning None to skip this batch.")
|
166 |
+
return None
|
167 |
+
|
168 |
+
try:
|
169 |
+
inputs = torch.cat(inputs, dim=0)
|
170 |
+
labels = torch.cat(labels, dim=0)
|
171 |
+
segment_labels = torch.cat(segment_labels, dim=0)
|
172 |
+
except Exception as e:
|
173 |
+
print(f"Error concatenating tensors: {e}")
|
174 |
+
return None
|
175 |
+
|
176 |
+
return {
|
177 |
+
'input': inputs,
|
178 |
+
'label': labels,
|
179 |
+
'segment_label': segment_labels
|
180 |
+
}
|
181 |
+
|
182 |
+
def custom_collate_fn(batch):
|
183 |
+
processed_batch = collate_fn(batch)
|
184 |
+
|
185 |
+
if processed_batch is None or len(processed_batch['input']) == 0:
|
186 |
+
# Return a valid batch with at least one element instead of an empty one
|
187 |
+
return {
|
188 |
+
'input': torch.zeros((1, 128), dtype=torch.long),
|
189 |
+
'label': torch.zeros((1,), dtype=torch.long),
|
190 |
+
'segment_label': torch.zeros((1, 128), dtype=torch.long)
|
191 |
+
}
|
192 |
+
|
193 |
+
return processed_batch
|
194 |
+
|
195 |
+
|
196 |
+
def train_without_progress_status(trainer, epoch, shuffle):
|
197 |
+
for epoch_idx in range(epoch):
|
198 |
+
print(f"EP_train:{epoch_idx}:")
|
199 |
+
for batch in trainer.train_data:
|
200 |
+
if batch is None:
|
201 |
+
continue
|
202 |
+
|
203 |
+
# Check if batch is a string (indicating an issue)
|
204 |
+
if isinstance(batch, str):
|
205 |
+
print(f"Error: Received a string instead of a dictionary in batch: {batch}")
|
206 |
+
raise ValueError(f"Unexpected string in batch: {batch}")
|
207 |
+
|
208 |
+
# Validate the batch structure before passing to iteration
|
209 |
+
if isinstance(batch, dict):
|
210 |
+
# Verify that all expected keys are present and that the values are tensors
|
211 |
+
if all(key in batch for key in ['input_ids', 'segment_label', 'labels']):
|
212 |
+
if all(isinstance(batch[key], torch.Tensor) for key in batch):
|
213 |
+
try:
|
214 |
+
print(f"Batch Structure: {batch}") # Debugging batch before iteration
|
215 |
+
trainer.iteration(epoch_idx, batch)
|
216 |
+
except Exception as e:
|
217 |
+
print(f"Error during batch processing: {e}")
|
218 |
+
sys.stdout.flush()
|
219 |
+
raise e # Propagate the exception for better debugging
|
220 |
+
else:
|
221 |
+
print(f"Error: Expected all values in batch to be tensors, but got: {batch}")
|
222 |
+
raise ValueError("Batch contains non-tensor values.")
|
223 |
+
else:
|
224 |
+
print(f"Error: Batch missing expected keys. Batch keys: {batch.keys()}")
|
225 |
+
raise ValueError("Batch does not contain expected keys.")
|
226 |
+
else:
|
227 |
+
print(f"Error: Expected batch to be a dictionary but got {type(batch)} instead.")
|
228 |
+
raise ValueError(f"Invalid batch structure: {batch}")
|
229 |
+
|
230 |
+
# def main(opt):
|
231 |
+
# # device = torch.device("cpu")
|
232 |
+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
233 |
+
|
234 |
+
# vocab = Vocab(opt.vocab_file)
|
235 |
+
# vocab.load_vocab()
|
236 |
+
|
237 |
+
# input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=128)
|
238 |
+
# labels = preprocess_labels(opt.dataset)
|
239 |
+
|
240 |
+
# if input_ids is None or segment_labels is None or labels is None:
|
241 |
+
# print("Error in preprocessing data. Exiting.")
|
242 |
+
# return
|
243 |
+
|
244 |
+
# dataset = TensorDataset(input_ids, segment_labels, torch.tensor(labels, dtype=torch.long))
|
245 |
+
# val_size = len(dataset) - int(0.8 * len(dataset))
|
246 |
+
# val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])
|
247 |
+
|
248 |
+
# train_dataloader = DataLoader(
|
249 |
+
# train_dataset,
|
250 |
+
# batch_size=32,
|
251 |
+
# shuffle=True,
|
252 |
+
# collate_fn=custom_collate_fn
|
253 |
+
# )
|
254 |
+
# val_dataloader = DataLoader(
|
255 |
+
# val_dataset,
|
256 |
+
# batch_size=32,
|
257 |
+
# shuffle=False,
|
258 |
+
# collate_fn=custom_collate_fn
|
259 |
+
# )
|
260 |
+
|
261 |
+
# custom_model = CustomBERTModel(
|
262 |
+
# vocab_size=len(vocab.vocab),
|
263 |
+
# output_dim=2,
|
264 |
+
# pre_trained_model_path=opt.pre_trained_model_path
|
265 |
+
# ).to(device)
|
266 |
+
|
267 |
+
# trainer = BERTFineTuneTrainer1(
|
268 |
+
# bert=custom_model.bert,
|
269 |
+
# vocab_size=len(vocab.vocab),
|
270 |
+
# train_dataloader=train_dataloader,
|
271 |
+
# test_dataloader=val_dataloader,
|
272 |
+
# lr=5e-5,
|
273 |
+
# num_labels=2,
|
274 |
+
# with_cuda=torch.cuda.is_available(),
|
275 |
+
# log_freq=10,
|
276 |
+
# workspace_name=opt.output_dir,
|
277 |
+
# log_folder_path=opt.log_folder_path
|
278 |
+
# )
|
279 |
+
|
280 |
+
# trainer.train(epoch=20)
|
281 |
+
|
282 |
+
# # os.makedirs(opt.output_dir, exist_ok=True)
|
283 |
+
# # output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model.pth')
|
284 |
+
# # torch.save(custom_model.state_dict(), output_model_file)
|
285 |
+
# # print(f'Model saved to {output_model_file}')
|
286 |
+
|
287 |
+
# os.makedirs(opt.output_dir, exist_ok=True)
|
288 |
+
# output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
|
289 |
+
# torch.save(custom_model, output_model_file)
|
290 |
+
# print(f'Model saved to {output_model_file}')
|
291 |
+
|
292 |
+
|
293 |
+
def main(opt):
|
294 |
+
# Set device to GPU if available, otherwise use CPU
|
295 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
296 |
+
|
297 |
+
print(torch.cuda.is_available()) # Should return True if GPU is available
|
298 |
+
print(torch.cuda.device_count())
|
299 |
+
|
300 |
+
# Load vocabulary
|
301 |
+
vocab = Vocab(opt.vocab_file)
|
302 |
+
vocab.load_vocab()
|
303 |
+
|
304 |
+
# Preprocess data and labels
|
305 |
+
input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=128)
|
306 |
+
labels = preprocess_labels(opt.dataset)
|
307 |
+
|
308 |
+
if input_ids is None or segment_labels is None or labels is None:
|
309 |
+
print("Error in preprocessing data. Exiting.")
|
310 |
+
return
|
311 |
+
|
312 |
+
# Transfer tensors to the correct device (GPU/CPU)
|
313 |
+
input_ids = input_ids.to(device)
|
314 |
+
segment_labels = segment_labels.to(device)
|
315 |
+
labels = torch.tensor(labels, dtype=torch.long).to(device)
|
316 |
+
|
317 |
+
# Create TensorDataset and split into train and validation sets
|
318 |
+
dataset = TensorDataset(input_ids, segment_labels, labels)
|
319 |
+
val_size = len(dataset) - int(0.8 * len(dataset))
|
320 |
+
val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])
|
321 |
+
|
322 |
+
# Create DataLoaders for training and validation
|
323 |
+
train_dataloader = DataLoader(
|
324 |
+
train_dataset,
|
325 |
+
batch_size=32,
|
326 |
+
shuffle=True,
|
327 |
+
collate_fn=custom_collate_fn
|
328 |
+
)
|
329 |
+
val_dataloader = DataLoader(
|
330 |
+
val_dataset,
|
331 |
+
batch_size=32,
|
332 |
+
shuffle=False,
|
333 |
+
collate_fn=custom_collate_fn
|
334 |
+
)
|
335 |
+
|
336 |
+
# Initialize custom BERT model and move it to the device
|
337 |
+
custom_model = CustomBERTModel(
|
338 |
+
vocab_size=len(vocab.vocab),
|
339 |
+
output_dim=2,
|
340 |
+
pre_trained_model_path=opt.pre_trained_model_path
|
341 |
+
).to(device)
|
342 |
+
|
343 |
+
# Initialize the fine-tuning trainer
|
344 |
+
trainer = BERTFineTuneTrainer1(
|
345 |
+
bert=custom_model.bert,
|
346 |
+
vocab_size=len(vocab.vocab),
|
347 |
+
train_dataloader=train_dataloader,
|
348 |
+
test_dataloader=val_dataloader,
|
349 |
+
lr=5e-5,
|
350 |
+
num_labels=2,
|
351 |
+
with_cuda=torch.cuda.is_available(),
|
352 |
+
log_freq=10,
|
353 |
+
workspace_name=opt.output_dir,
|
354 |
+
log_folder_path=opt.log_folder_path
|
355 |
+
)
|
356 |
+
|
357 |
+
# Train the model
|
358 |
+
trainer.train(epoch=20)
|
359 |
+
|
360 |
+
# Save the model to the specified output directory
|
361 |
+
# os.makedirs(opt.output_dir, exist_ok=True)
|
362 |
+
# output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
|
363 |
+
# torch.save(custom_model.state_dict(), output_model_file)
|
364 |
+
# print(f'Model saved to {output_model_file}')
|
365 |
+
os.makedirs(opt.output_dir, exist_ok=True)
|
366 |
+
output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
|
367 |
+
torch.save(custom_model, output_model_file)
|
368 |
+
print(f'Model saved to {output_model_file}')
|
369 |
+
|
370 |
+
|
371 |
+
if __name__ == '__main__':
|
372 |
+
parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
|
373 |
+
parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the dataset file.')
|
374 |
+
parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
|
375 |
+
parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
|
376 |
+
parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
|
377 |
+
parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
|
378 |
+
parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct_logs', help='Path to the folder for saving logs.')
|
379 |
+
|
380 |
+
|
381 |
+
opt = parser.parse_args()
|
382 |
+
main(opt)
|
main.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from src.bert import BERT
|
8 |
+
from src.pretrainer import BERTTrainer, BERTFineTuneTrainer, BERTAttention
|
9 |
+
from src.dataset import PretrainerDataset, TokenizerDataset
|
10 |
+
from src.vocab import Vocab
|
11 |
+
|
12 |
+
import time
|
13 |
+
import os
|
14 |
+
import tqdm
|
15 |
+
import pickle
|
16 |
+
|
17 |
+
def train():
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
|
20 |
+
parser.add_argument('-workspace_name', type=str, default=None)
|
21 |
+
parser.add_argument('-code', type=str, default=None, help="folder for pretraining outputs and logs")
|
22 |
+
parser.add_argument('-finetune_task', type=str, default=None, help="folder inside finetuning")
|
23 |
+
parser.add_argument("-attention", type=bool, default=False, help="analyse attention scores")
|
24 |
+
parser.add_argument("-diff_test_folder", type=bool, default=False, help="use for different test folder")
|
25 |
+
parser.add_argument("-embeddings", type=bool, default=False, help="get and analyse embeddings")
|
26 |
+
parser.add_argument('-embeddings_file_name', type=str, default=None, help="file name of embeddings")
|
27 |
+
parser.add_argument("-pretrain", type=bool, default=False, help="pretraining: true, or false")
|
28 |
+
# parser.add_argument('-opts', nargs='+', type=str, default=None, help='List of optional steps')
|
29 |
+
parser.add_argument("-max_mask", type=int, default=0.15, help="% of input tokens selected for masking")
|
30 |
+
# parser.add_argument("-p", "--pretrain_dataset", type=str, default="pretraining/pretrain.txt", help="pretraining dataset for bert")
|
31 |
+
# parser.add_argument("-pv", "--pretrain_val_dataset", type=str, default="pretraining/test.txt", help="pretraining validation dataset for bert")
|
32 |
+
# default="finetuning/test.txt",
|
33 |
+
parser.add_argument("-vocab_path", type=str, default="pretraining/vocab.txt", help="built vocab model path with bert-vocab")
|
34 |
+
|
35 |
+
parser.add_argument("-train_dataset_path", type=str, default="train.txt", help="fine tune train dataset for progress classifier")
|
36 |
+
parser.add_argument("-val_dataset_path", type=str, default="val.txt", help="test set for evaluate fine tune train set")
|
37 |
+
parser.add_argument("-test_dataset_path", type=str, default="test.txt", help="test set for evaluate fine tune train set")
|
38 |
+
parser.add_argument("-num_labels", type=int, default=2, help="Number of labels")
|
39 |
+
parser.add_argument("-train_label_path", type=str, default="train_label.txt", help="fine tune train dataset for progress classifier")
|
40 |
+
parser.add_argument("-val_label_path", type=str, default="val_label.txt", help="test set for evaluate fine tune train set")
|
41 |
+
parser.add_argument("-test_label_path", type=str, default="test_label.txt", help="test set for evaluate fine tune train set")
|
42 |
+
##### change Checkpoint for finetuning
|
43 |
+
parser.add_argument("-pretrained_bert_checkpoint", type=str, default=None, help="checkpoint of saved pretrained bert model") #."output_feb09/bert_trained.model.ep40"
|
44 |
+
parser.add_argument('-check_epoch', type=int, default=None)
|
45 |
+
|
46 |
+
parser.add_argument("-hs", "--hidden", type=int, default=64, help="hidden size of transformer model") #64
|
47 |
+
parser.add_argument("-l", "--layers", type=int, default=4, help="number of layers") #4
|
48 |
+
parser.add_argument("-a", "--attn_heads", type=int, default=4, help="number of attention heads") #8
|
49 |
+
parser.add_argument("-s", "--seq_len", type=int, default=50, help="maximum sequence length")
|
50 |
+
|
51 |
+
parser.add_argument("-b", "--batch_size", type=int, default=500, help="number of batch_size") #64
|
52 |
+
parser.add_argument("-e", "--epochs", type=int, default=50)#1501, help="number of epochs") #501
|
53 |
+
# Use 50 for pretrain, and 10 for fine tune
|
54 |
+
parser.add_argument("-w", "--num_workers", type=int, default=4, help="dataloader worker size")
|
55 |
+
|
56 |
+
# Later run with cuda
|
57 |
+
parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false")
|
58 |
+
parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
|
59 |
+
# parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
|
60 |
+
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
|
61 |
+
# parser.add_argument("--on_memory", type=bool, default=False, help="Loading on memory: true or false")
|
62 |
+
|
63 |
+
parser.add_argument("--dropout", type=float, default=0.1, help="dropout of network")
|
64 |
+
parser.add_argument("--lr", type=float, default=1e-05, help="learning rate of adam") #1e-3
|
65 |
+
parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
|
66 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
|
67 |
+
parser.add_argument("--adam_beta2", type=float, default=0.98, help="adam first beta value") #0.999
|
68 |
+
|
69 |
+
parser.add_argument("-o", "--output_path", type=str, default="bert_trained.seq_encoder.model", help="ex)output/bert.model")
|
70 |
+
# parser.add_argument("-o", "--output_path", type=str, default="output/bert_fine_tuned.model", help="ex)output/bert.model")
|
71 |
+
|
72 |
+
args = parser.parse_args()
|
73 |
+
for k,v in vars(args).items():
|
74 |
+
if 'path' in k:
|
75 |
+
if v:
|
76 |
+
if k == "output_path":
|
77 |
+
if args.code:
|
78 |
+
setattr(args, f"{k}", args.workspace_name+f"/output/{args.code}/"+v)
|
79 |
+
elif args.finetune_task:
|
80 |
+
setattr(args, f"{k}", args.workspace_name+f"/output/{args.finetune_task}/"+v)
|
81 |
+
else:
|
82 |
+
setattr(args, f"{k}", args.workspace_name+"/output/"+v)
|
83 |
+
elif k != "vocab_path":
|
84 |
+
if args.pretrain:
|
85 |
+
setattr(args, f"{k}", args.workspace_name+"/pretraining/"+v)
|
86 |
+
else:
|
87 |
+
if args.code:
|
88 |
+
setattr(args, f"{k}", args.workspace_name+f"/{args.code}/"+v)
|
89 |
+
elif args.finetune_task:
|
90 |
+
if args.diff_test_folder and "test" in k:
|
91 |
+
setattr(args, f"{k}", args.workspace_name+f"/finetuning/"+v)
|
92 |
+
else:
|
93 |
+
setattr(args, f"{k}", args.workspace_name+f"/finetuning/{args.finetune_task}/"+v)
|
94 |
+
else:
|
95 |
+
setattr(args, f"{k}", args.workspace_name+"/finetuning/"+v)
|
96 |
+
else:
|
97 |
+
setattr(args, f"{k}", args.workspace_name+"/"+v)
|
98 |
+
|
99 |
+
print(f"args.{k} : {getattr(args, f'{k}')}")
|
100 |
+
|
101 |
+
print("Loading Vocab", args.vocab_path)
|
102 |
+
vocab_obj = Vocab(args.vocab_path)
|
103 |
+
vocab_obj.load_vocab()
|
104 |
+
print("Vocab Size: ", len(vocab_obj.vocab))
|
105 |
+
|
106 |
+
if args.attention:
|
107 |
+
print(f"Attention aggregate...... code: {args.code}, dataset: {args.finetune_task}")
|
108 |
+
if args.code:
|
109 |
+
new_folder = f"{args.workspace_name}/plots/{args.code}/"
|
110 |
+
if not os.path.exists(new_folder):
|
111 |
+
os.makedirs(new_folder)
|
112 |
+
|
113 |
+
train_dataset = TokenizerDataset(args.train_dataset_path, None, vocab_obj, seq_len=args.seq_len)
|
114 |
+
train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
|
115 |
+
print("Load Pre-trained BERT model")
|
116 |
+
cuda_condition = torch.cuda.is_available() and args.with_cuda
|
117 |
+
device = torch.device("cuda:0" if cuda_condition else "cpu")
|
118 |
+
bert = torch.load(args.pretrained_bert_checkpoint, map_location=device)
|
119 |
+
trainer = BERTAttention(bert, vocab_obj, train_dataloader = train_data_loader, workspace_name = args.workspace_name, code=args.code, finetune_task = args.finetune_task)
|
120 |
+
trainer.getAttention()
|
121 |
+
|
122 |
+
elif args.embeddings:
|
123 |
+
print("Get embeddings... and cluster... ")
|
124 |
+
train_dataset = TokenizerDataset(args.test_dataset_path, None, vocab_obj, seq_len=args.seq_len)
|
125 |
+
train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
|
126 |
+
print("Load Pre-trained BERT model")
|
127 |
+
cuda_condition = torch.cuda.is_available() and args.with_cuda
|
128 |
+
device = torch.device("cuda:0" if cuda_condition else "cpu")
|
129 |
+
bert = torch.load(args.pretrained_bert_checkpoint).to(device)
|
130 |
+
available_gpus = list(range(torch.cuda.device_count()))
|
131 |
+
if torch.cuda.device_count() > 1:
|
132 |
+
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
133 |
+
bert = nn.DataParallel(bert, device_ids=available_gpus)
|
134 |
+
|
135 |
+
data_iter = tqdm.tqdm(enumerate(train_data_loader),
|
136 |
+
desc="Model: %s" % (args.pretrained_bert_checkpoint.split("/")[-1]),
|
137 |
+
total=len(train_data_loader), bar_format="{l_bar}{r_bar}")
|
138 |
+
all_embeddings = []
|
139 |
+
for i, data in data_iter:
|
140 |
+
data = {key: value.to(device) for key, value in data.items()}
|
141 |
+
embedding = bert(data["input"], data["segment_label"])
|
142 |
+
# print(embedding.shape, embedding[:, 0].shape)
|
143 |
+
embeddings = [h for h in embedding[:,0].cpu().detach().numpy()]
|
144 |
+
all_embeddings.extend(embeddings)
|
145 |
+
|
146 |
+
new_emb_folder = f"{args.workspace_name}/embeddings"
|
147 |
+
if not os.path.exists(new_emb_folder):
|
148 |
+
os.makedirs(new_emb_folder)
|
149 |
+
pickle.dump(all_embeddings, open(f"{new_emb_folder}/{args.embeddings_file_name}.pkl", "wb"))
|
150 |
+
else:
|
151 |
+
if args.pretrain:
|
152 |
+
print("Pre-training......")
|
153 |
+
print("Loading Pretraining Train Dataset", args.train_dataset_path)
|
154 |
+
print(f"Workspace: {args.workspace_name}")
|
155 |
+
pretrain_dataset = PretrainerDataset(args.train_dataset_path, vocab_obj, seq_len=args.seq_len, max_mask = args.max_mask)
|
156 |
+
|
157 |
+
print("Loading Pretraining Validation Dataset", args.val_dataset_path)
|
158 |
+
pretrain_valid_dataset = PretrainerDataset(args.val_dataset_path, vocab_obj, seq_len=args.seq_len, max_mask = args.max_mask) \
|
159 |
+
if args.val_dataset_path is not None else None
|
160 |
+
|
161 |
+
print("Loading Pretraining Test Dataset", args.test_dataset_path)
|
162 |
+
pretrain_test_dataset = PretrainerDataset(args.test_dataset_path, vocab_obj, seq_len=args.seq_len, max_mask = args.max_mask) \
|
163 |
+
if args.test_dataset_path is not None else None
|
164 |
+
|
165 |
+
print("Creating Dataloader")
|
166 |
+
pretrain_data_loader = DataLoader(pretrain_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
|
167 |
+
pretrain_val_data_loader = DataLoader(pretrain_valid_dataset, batch_size=args.batch_size, num_workers=args.num_workers)\
|
168 |
+
if pretrain_valid_dataset is not None else None
|
169 |
+
pretrain_test_data_loader = DataLoader(pretrain_test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)\
|
170 |
+
if pretrain_test_dataset is not None else None
|
171 |
+
|
172 |
+
print("Building BERT model")
|
173 |
+
bert = BERT(len(vocab_obj.vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads, dropout=args.dropout)
|
174 |
+
|
175 |
+
if args.pretrained_bert_checkpoint:
|
176 |
+
print(f"BERT model : {args.pretrained_bert_checkpoint}")
|
177 |
+
bert = torch.load(args.pretrained_bert_checkpoint)
|
178 |
+
|
179 |
+
new_log_folder = f"{args.workspace_name}/logs"
|
180 |
+
new_output_folder = f"{args.workspace_name}/output"
|
181 |
+
if args.code: # is sent almost all the time
|
182 |
+
new_log_folder = f"{args.workspace_name}/logs/{args.code}"
|
183 |
+
new_output_folder = f"{args.workspace_name}/output/{args.code}"
|
184 |
+
|
185 |
+
if not os.path.exists(new_log_folder):
|
186 |
+
os.makedirs(new_log_folder)
|
187 |
+
if not os.path.exists(new_output_folder):
|
188 |
+
os.makedirs(new_output_folder)
|
189 |
+
|
190 |
+
print(f"Creating BERT Trainer .... masking: True, max_mask: {args.max_mask}")
|
191 |
+
trainer = BERTTrainer(bert, len(vocab_obj.vocab), train_dataloader=pretrain_data_loader,
|
192 |
+
val_dataloader=pretrain_val_data_loader, test_dataloader=pretrain_test_data_loader,
|
193 |
+
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
|
194 |
+
with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq,
|
195 |
+
log_folder_path=new_log_folder)
|
196 |
+
|
197 |
+
start_time = time.time()
|
198 |
+
print(f'Pretraining Starts, Time: {time.strftime("%D %T", time.localtime(start_time))}')
|
199 |
+
# if need to pretrain from a check-point, need :check_epoch
|
200 |
+
repoch = range(args.check_epoch, args.epochs) if args.check_epoch else range(args.epochs)
|
201 |
+
counter = 0
|
202 |
+
patience = 20
|
203 |
+
for epoch in repoch:
|
204 |
+
print(f'Training Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
205 |
+
trainer.train(epoch)
|
206 |
+
print(f'Training Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
207 |
+
|
208 |
+
if pretrain_val_data_loader is not None:
|
209 |
+
print(f'Validation Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
210 |
+
trainer.val(epoch)
|
211 |
+
print(f'Validation Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
212 |
+
|
213 |
+
if trainer.save_model: # or epoch%10 == 0 and epoch > 4
|
214 |
+
trainer.save(epoch, args.output_path)
|
215 |
+
counter = 0
|
216 |
+
if pretrain_test_data_loader is not None:
|
217 |
+
print(f'Test Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
218 |
+
trainer.test(epoch)
|
219 |
+
print(f'Test Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
220 |
+
else:
|
221 |
+
counter +=1
|
222 |
+
if counter >= patience:
|
223 |
+
print(f"Early stopping at epoch {epoch}")
|
224 |
+
break
|
225 |
+
|
226 |
+
end_time = time.time()
|
227 |
+
print("Time Taken to pretrain model = ", end_time - start_time)
|
228 |
+
print(f'Pretraining Ends, Time: {time.strftime("%D %T", time.localtime(end_time))}')
|
229 |
+
else:
|
230 |
+
print("Fine Tuning......")
|
231 |
+
print("Loading Train Dataset", args.train_dataset_path)
|
232 |
+
train_dataset = TokenizerDataset(args.train_dataset_path, args.train_label_path, vocab_obj, seq_len=args.seq_len)
|
233 |
+
|
234 |
+
# print("Loading Validation Dataset", args.val_dataset_path)
|
235 |
+
# val_dataset = TokenizerDataset(args.val_dataset_path, args.val_label_path, vocab_obj, seq_len=args.seq_len) \
|
236 |
+
# if args.val_dataset_path is not None else None
|
237 |
+
|
238 |
+
print("Loading Test Dataset", args.test_dataset_path)
|
239 |
+
test_dataset = TokenizerDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len) \
|
240 |
+
if args.test_dataset_path is not None else None
|
241 |
+
|
242 |
+
print("Creating Dataloader...")
|
243 |
+
train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
|
244 |
+
# val_data_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
|
245 |
+
# if val_dataset is not None else None
|
246 |
+
test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
|
247 |
+
if test_dataset is not None else None
|
248 |
+
|
249 |
+
print("Load Pre-trained BERT model")
|
250 |
+
# bert = BERT(len(vocab_obj.vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
|
251 |
+
cuda_condition = torch.cuda.is_available() and args.with_cuda
|
252 |
+
device = torch.device("cuda:0" if cuda_condition else "cpu")
|
253 |
+
bert = torch.load(args.pretrained_bert_checkpoint, map_location=device)
|
254 |
+
|
255 |
+
# if args.finetune_task == "SL":
|
256 |
+
# if args.workspace_name == "ratio_proportion_change4":
|
257 |
+
# num_labels = 9
|
258 |
+
# elif args.workspace_name == "ratio_proportion_change3":
|
259 |
+
# num_labels = 9
|
260 |
+
# elif args.workspace_name == "scale_drawings_3":
|
261 |
+
# num_labels = 9
|
262 |
+
# elif args.workspace_name == "sales_tax_discounts_two_rates":
|
263 |
+
# num_labels = 3
|
264 |
+
# else:
|
265 |
+
# num_labels = 2
|
266 |
+
# # num_labels = 1
|
267 |
+
# print(f"Number of Labels : {args.num_labels}")
|
268 |
+
new_log_folder = f"{args.workspace_name}/logs"
|
269 |
+
new_output_folder = f"{args.workspace_name}/output"
|
270 |
+
if args.finetune_task: # is sent almost all the time
|
271 |
+
new_log_folder = f"{args.workspace_name}/logs/{args.finetune_task}"
|
272 |
+
new_output_folder = f"{args.workspace_name}/output/{args.finetune_task}"
|
273 |
+
|
274 |
+
if not os.path.exists(new_log_folder):
|
275 |
+
os.makedirs(new_log_folder)
|
276 |
+
if not os.path.exists(new_output_folder):
|
277 |
+
os.makedirs(new_output_folder)
|
278 |
+
|
279 |
+
print("Creating BERT Fine Tune Trainer")
|
280 |
+
trainer = BERTFineTuneTrainer(bert, len(vocab_obj.vocab),
|
281 |
+
train_dataloader=train_data_loader, test_dataloader=test_data_loader,
|
282 |
+
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
|
283 |
+
with_cuda=args.with_cuda, cuda_devices = args.cuda_devices, log_freq=args.log_freq,
|
284 |
+
workspace_name = args.workspace_name, num_labels=args.num_labels, log_folder_path=new_log_folder)
|
285 |
+
|
286 |
+
print("Fine-tune training Start....")
|
287 |
+
start_time = time.time()
|
288 |
+
repoch = range(args.check_epoch, args.epochs) if args.check_epoch else range(args.epochs)
|
289 |
+
counter = 0
|
290 |
+
patience = 10
|
291 |
+
for epoch in repoch:
|
292 |
+
print(f'Training Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
293 |
+
trainer.train(epoch)
|
294 |
+
print(f'Training Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
295 |
+
|
296 |
+
if test_data_loader is not None:
|
297 |
+
print(f'Test Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
298 |
+
trainer.test(epoch)
|
299 |
+
# pickle.dump(trainer.probability_list, open(f"{args.workspace_name}/output/aaai/change4_mid_prob_{epoch}.pkl","wb"))
|
300 |
+
print(f'Test Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
301 |
+
|
302 |
+
# if val_data_loader is not None:
|
303 |
+
# print(f'Validation Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
304 |
+
# trainer.val(epoch)
|
305 |
+
# print(f'Validation Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
306 |
+
|
307 |
+
if trainer.save_model: # or epoch%10 == 0
|
308 |
+
trainer.save(epoch, args.output_path)
|
309 |
+
counter = 0
|
310 |
+
else:
|
311 |
+
counter +=1
|
312 |
+
if counter >= patience:
|
313 |
+
print(f"Early stopping at epoch {epoch}")
|
314 |
+
break
|
315 |
+
|
316 |
+
end_time = time.time()
|
317 |
+
print("Time Taken to fine-tune model = ", end_time - start_time)
|
318 |
+
print(f'Pretraining Ends, Time: {time.strftime("%D %T", time.localtime(end_time))}')
|
319 |
+
|
320 |
+
|
321 |
+
if __name__ == "__main__":
|
322 |
+
train()
|
metrics.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.special import softmax
|
3 |
+
|
4 |
+
|
5 |
+
class CELoss(object):
|
6 |
+
|
7 |
+
def compute_bin_boundaries(self, probabilities = np.array([])):
|
8 |
+
|
9 |
+
#uniform bin spacing
|
10 |
+
if probabilities.size == 0:
|
11 |
+
bin_boundaries = np.linspace(0, 1, self.n_bins + 1)
|
12 |
+
self.bin_lowers = bin_boundaries[:-1]
|
13 |
+
self.bin_uppers = bin_boundaries[1:]
|
14 |
+
else:
|
15 |
+
#size of bins
|
16 |
+
bin_n = int(self.n_data/self.n_bins)
|
17 |
+
|
18 |
+
bin_boundaries = np.array([])
|
19 |
+
|
20 |
+
probabilities_sort = np.sort(probabilities)
|
21 |
+
|
22 |
+
for i in range(0,self.n_bins):
|
23 |
+
bin_boundaries = np.append(bin_boundaries,probabilities_sort[i*bin_n])
|
24 |
+
bin_boundaries = np.append(bin_boundaries,1.0)
|
25 |
+
|
26 |
+
self.bin_lowers = bin_boundaries[:-1]
|
27 |
+
self.bin_uppers = bin_boundaries[1:]
|
28 |
+
|
29 |
+
|
30 |
+
def get_probabilities(self, output, labels, logits):
|
31 |
+
#If not probabilities apply softmax!
|
32 |
+
if logits:
|
33 |
+
self.probabilities = softmax(output, axis=1)
|
34 |
+
else:
|
35 |
+
self.probabilities = output
|
36 |
+
|
37 |
+
self.labels = labels
|
38 |
+
self.confidences = np.max(self.probabilities, axis=1)
|
39 |
+
self.predictions = np.argmax(self.probabilities, axis=1)
|
40 |
+
self.accuracies = np.equal(self.predictions,labels)
|
41 |
+
|
42 |
+
def binary_matrices(self):
|
43 |
+
idx = np.arange(self.n_data)
|
44 |
+
#make matrices of zeros
|
45 |
+
pred_matrix = np.zeros([self.n_data,self.n_class])
|
46 |
+
label_matrix = np.zeros([self.n_data,self.n_class])
|
47 |
+
#self.acc_matrix = np.zeros([self.n_data,self.n_class])
|
48 |
+
pred_matrix[idx,self.predictions] = 1
|
49 |
+
label_matrix[idx,self.labels] = 1
|
50 |
+
|
51 |
+
self.acc_matrix = np.equal(pred_matrix, label_matrix)
|
52 |
+
|
53 |
+
|
54 |
+
def compute_bins(self, index = None):
|
55 |
+
self.bin_prop = np.zeros(self.n_bins)
|
56 |
+
self.bin_acc = np.zeros(self.n_bins)
|
57 |
+
self.bin_conf = np.zeros(self.n_bins)
|
58 |
+
self.bin_score = np.zeros(self.n_bins)
|
59 |
+
|
60 |
+
if index == None:
|
61 |
+
confidences = self.confidences
|
62 |
+
accuracies = self.accuracies
|
63 |
+
else:
|
64 |
+
confidences = self.probabilities[:,index]
|
65 |
+
accuracies = self.acc_matrix[:,index]
|
66 |
+
|
67 |
+
|
68 |
+
for i, (bin_lower, bin_upper) in enumerate(zip(self.bin_lowers, self.bin_uppers)):
|
69 |
+
# Calculated |confidence - accuracy| in each bin
|
70 |
+
in_bin = np.greater(confidences,bin_lower.item()) * np.less_equal(confidences,bin_upper.item())
|
71 |
+
self.bin_prop[i] = np.mean(in_bin)
|
72 |
+
|
73 |
+
if self.bin_prop[i].item() > 0:
|
74 |
+
self.bin_acc[i] = np.mean(accuracies[in_bin])
|
75 |
+
self.bin_conf[i] = np.mean(confidences[in_bin])
|
76 |
+
self.bin_score[i] = np.abs(self.bin_conf[i] - self.bin_acc[i])
|
77 |
+
|
78 |
+
class MaxProbCELoss(CELoss):
|
79 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
80 |
+
self.n_bins = n_bins
|
81 |
+
super().compute_bin_boundaries()
|
82 |
+
super().get_probabilities(output, labels, logits)
|
83 |
+
super().compute_bins()
|
84 |
+
|
85 |
+
#http://people.cs.pitt.edu/~milos/research/AAAI_Calibration.pdf
|
86 |
+
class ECELoss(MaxProbCELoss):
|
87 |
+
|
88 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
89 |
+
super().loss(output, labels, n_bins, logits)
|
90 |
+
return np.dot(self.bin_prop,self.bin_score)
|
91 |
+
|
92 |
+
class MCELoss(MaxProbCELoss):
|
93 |
+
|
94 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
95 |
+
super().loss(output, labels, n_bins, logits)
|
96 |
+
return np.max(self.bin_score)
|
97 |
+
|
98 |
+
#https://arxiv.org/abs/1905.11001
|
99 |
+
#Overconfidence Loss (Good in high risk applications where confident but wrong predictions can be especially harmful)
|
100 |
+
class OELoss(MaxProbCELoss):
|
101 |
+
|
102 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
103 |
+
super().loss(output, labels, n_bins, logits)
|
104 |
+
return np.dot(self.bin_prop,self.bin_conf * np.maximum(self.bin_conf-self.bin_acc,np.zeros(self.n_bins)))
|
105 |
+
|
106 |
+
|
107 |
+
#https://arxiv.org/abs/1904.01685
|
108 |
+
class SCELoss(CELoss):
|
109 |
+
|
110 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
111 |
+
sce = 0.0
|
112 |
+
self.n_bins = n_bins
|
113 |
+
self.n_data = len(output)
|
114 |
+
self.n_class = len(output[0])
|
115 |
+
|
116 |
+
super().compute_bin_boundaries()
|
117 |
+
super().get_probabilities(output, labels, logits)
|
118 |
+
super().binary_matrices()
|
119 |
+
|
120 |
+
for i in range(self.n_class):
|
121 |
+
super().compute_bins(i)
|
122 |
+
sce += np.dot(self.bin_prop,self.bin_score)
|
123 |
+
|
124 |
+
return sce/self.n_class
|
125 |
+
|
126 |
+
class TACELoss(CELoss):
|
127 |
+
|
128 |
+
def loss(self, output, labels, threshold = 0.01, n_bins = 15, logits = True):
|
129 |
+
tace = 0.0
|
130 |
+
self.n_bins = n_bins
|
131 |
+
self.n_data = len(output)
|
132 |
+
self.n_class = len(output[0])
|
133 |
+
|
134 |
+
super().get_probabilities(output, labels, logits)
|
135 |
+
self.probabilities[self.probabilities < threshold] = 0
|
136 |
+
super().binary_matrices()
|
137 |
+
|
138 |
+
for i in range(self.n_class):
|
139 |
+
super().compute_bin_boundaries(self.probabilities[:,i])
|
140 |
+
super().compute_bins(i)
|
141 |
+
tace += np.dot(self.bin_prop,self.bin_score)
|
142 |
+
|
143 |
+
return tace/self.n_class
|
144 |
+
|
145 |
+
#create TACELoss with threshold fixed at 0
|
146 |
+
class ACELoss(TACELoss):
|
147 |
+
|
148 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
149 |
+
return super().loss(output, labels, 0.0 , n_bins, logits)
|
new_fine_tuning/README.md
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Pre-training Data
|
2 |
+
|
3 |
+
### ratio_proportion_change3 : Calculating Percent Change and Final Amounts
|
4 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -analyze_dataset_by_section True -workspace_name ratio_proportion_change3 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -pretrain True -train_file_path pretraining/pretrain1000.txt -train_info_path pretraining/pretrain1000_info.txt -test_file_path pretraining/test1000.txt -test_info_path pretraining/test1000_info.txt
|
5 |
+
|
6 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -pretrain True -train_file_path pretraining/pretrain2000.txt -train_info_path pretraining/pretrain2000_info.txt -test_file_path pretraining/test2000.txt -test_info_path pretraining/test2000_info.txt
|
7 |
+
|
8 |
+
#### Test simple
|
9 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -code full -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path full.txt -train_info_path full_info.txt
|
10 |
+
|
11 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -code gt -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path er.txt -train_info_path er_info.txt -test_file_path me.txt -test_info_path me_info.txt
|
12 |
+
|
13 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -code correct -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path correct.txt -train_info_path correct_info.txt -test_file_path incorrect.txt -test_info_path incorrect_info.txt -final_step FinalAnswer
|
14 |
+
|
15 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -code progress -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path graduated.txt -train_info_path graduated_info.txt -test_file_path promoted.txt -test_info_path promoted_info.txt
|
16 |
+
|
17 |
+
### ratio_proportion_change4 : Using Percents and Percent Change
|
18 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -analyze_dataset_by_section True -workspace_name ratio_proportion_change4 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor NumeratorLabel1 DenominatorLabel1 -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -pretrain True -train_file_path pretraining/pretrain1000.txt -train_info_path pretraining/pretrain1000_info.txt -test_file_path pretraining/test1000.txt -test_info_path pretraining/test1000_info.txt
|
19 |
+
|
20 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor NumeratorLabel1 DenominatorLabel1 -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -pretrain True -train_file_path pretraining/pretrain2000.txt -train_info_path pretraining/pretrain2000_info.txt -test_file_path pretraining/test2000.txt -test_info_path pretraining/test2000_info.txt
|
21 |
+
|
22 |
+
#### Test simple
|
23 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -code full -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path full.txt -train_info_path full_info.txt
|
24 |
+
|
25 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -code gt -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path er.txt -train_info_path er_info.txt -test_file_path me.txt -test_info_path me_info.txt
|
26 |
+
|
27 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -code correct -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path correct.txt -train_info_path correct_info.txt -test_file_path incorrect.txt -test_info_path incorrect_info.txt -final_step FinalAnswer
|
28 |
+
|
29 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -code progress -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path graduated.txt -train_info_path graduated_info.txt -test_file_path promoted.txt -test_info_path promoted_info.txt
|
30 |
+
|
31 |
+
## Pretraining
|
32 |
+
|
33 |
+
### ratio_proportion_change3 : Calculating Percent Change and Final Amounts
|
34 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3_1920 -code pretrain1000 --pretrain_dataset pretraining/pretrain1000.txt --pretrain_val_dataset pretraining/test1000.txt
|
35 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000 --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt
|
36 |
+
|
37 |
+
#### Test simple models
|
38 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 1 --attn_heads 1
|
39 |
+
|
40 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 1 --attn_heads 2
|
41 |
+
|
42 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 2 --attn_heads 2
|
43 |
+
|
44 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 2 --attn_heads 4
|
45 |
+
|
46 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 4 --attn_heads 4
|
47 |
+
|
48 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 4 --attn_heads 8
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
### ratio_proportion_change4 : Using Percents and Percent Change
|
53 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain1000 --pretrain_dataset pretraining/pretrain1000.txt --pretrain_val_dataset pretraining/test1000.txt
|
54 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000 --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt
|
55 |
+
|
56 |
+
#### Test simple models
|
57 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_1l1h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 1 --attn_heads 1
|
58 |
+
|
59 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_1l2h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 1 --attn_heads 2
|
60 |
+
|
61 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_2l2h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 2 --attn_heads 2
|
62 |
+
|
63 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_2l4h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 2 --attn_heads 4
|
64 |
+
|
65 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_4l4h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 4 --attn_heads 4
|
66 |
+
|
67 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_4l8h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 4 --attn_heads 8
|
68 |
+
|
69 |
+
|
70 |
+
## Preparing Fine Tuning Data
|
71 |
+
|
72 |
+
### ratio_proportion_change3 : Calculating Percent Change and Final Amounts
|
73 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -final_step FinalAnswer
|
74 |
+
|
75 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task check2 --train_dataset finetuning/check2/train.txt --test_dataset finetuning/check2/test.txt --train_label finetuning/check2/train_label.txt --test_label finetuning/check2/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279 --epochs 51
|
76 |
+
|
77 |
+
#### Attention Head Check
|
78 |
+
<!-- > PercentChange NumeratorQuantity2 NumeratorQuantity1 DenominatorQuantity1 OptionalTask_1 EquationAnswer NumeratorFactor EquationAnswer NumeratorFactor EquationAnswer NumeratorFactor DenominatorFactor NumeratorFactor DenominatorFactor NumeratorFactor DenominatorFactor FirstRow1:2 FirstRow1:1 FirstRow2:1 FirstRow2:2 FirstRow2:1 SecondRow ThirdRow FinalAnswerDirection ThirdRow FinalAnswer -->
|
79 |
+
|
80 |
+
|
81 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task er ;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task correct ;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task promoted
|
82 |
+
|
83 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task promoted
|
84 |
+
|
85 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task promoted
|
86 |
+
|
87 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task promoted
|
88 |
+
|
89 |
+
<!-- > clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep923 --attention True -->
|
90 |
+
|
91 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task promoted
|
92 |
+
|
93 |
+
clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset full/full_attn.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task full
|
94 |
+
|
95 |
+
|
96 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task promoted
|
97 |
+
|
98 |
+
|
99 |
+
<!-- PercentChange NumeratorQuantity2 NumeratorQuantity1 DenominatorQuantity1 OptionalTask_2 FirstRow2:1 FirstRow2:2 FirstRow1:1 SecondRow ThirdRow FinalAnswer FinalAnswerDirection --> me
|
100 |
+
|
101 |
+
<!-- PercentChange NumeratorQuantity2 NumeratorQuantity1 DenominatorQuantity1 OptionalTask_1 DenominatorFactor NumeratorFactor OptionalTask_2 EquationAnswer FirstRow1:1 FirstRow1:2 FirstRow2:2 FirstRow2:1 FirstRow1:2 SecondRow ThirdRow FinalAnswer --> er
|
102 |
+
|
103 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset pretraining/attention_train.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep273 --attention True
|
104 |
+
|
105 |
+
<!-- PercentChange NumeratorQuantity2 NumeratorQuantity1 DenominatorQuantity1 OptionalTask_1 DenominatorFactor NumeratorFactor OptionalTask_2 EquationAnswer FirstRow1:1 FirstRow1:2 FirstRow2:2 FirstRow2:1 FirstRow1:2 SecondRow ThirdRow FinalAnswer -->
|
106 |
+
|
107 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset pretraining/attention_train.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep1021 --attention True
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
### ratio_proportion_change4 : Using Percents and Percent Change
|
112 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor NumeratorLabel1 DenominatorLabel1 -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -final_step FinalAnswer
|
113 |
+
|
114 |
+
### scale_drawings_3 : Calculating Measurements Using a Scale
|
115 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name scale_drawings_3 -opt_step1 opt1-check opt1-ratio-L-n opt1-ratio-L-d opt1-ratio-R-n opt1-ratio-R-d opt1-me2-top-3 opt1-me2-top-4 opt1-me2-top-2 opt1-me2-top-1 opt1-me2-middle-1 opt1-me2-bottom-1 -opt_step2 opt2-check opt2-ratio-L-n opt2-ratio-L-d opt2-ratio-R-n opt2-ratio-R-d opt2-me2-top-3 opt2-me2-top-4 opt2-me2-top-1 opt2-me2-top-2 opt2-me2-middle-1 opt2-me2-bottom-1 -final_step unk-value1 unk-value2
|
116 |
+
|
117 |
+
### sales_tax_discounts_two_rates : Solving Problems with Both Sales Tax and Discounts
|
118 |
+
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name sales_tax_discounts_two_rates -opt_step1 optionalTaskGn salestaxFactor2 discountFactor2 multiplyOrderStatementGn -final_step totalCost1
|
119 |
+
|
120 |
+
|
121 |
+
# Fine Tuning Pre-trained model
|
122 |
+
|
123 |
+
## ratio_proportion_change3 : Calculating Percent Change and Final Amounts
|
124 |
+
> Selected Pretrained model: **ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279**
|
125 |
+
> New **bert/ratio_proportion_change3/output/pretrain2000/bert_trained.seq_encoder.model.ep731**
|
126 |
+
|
127 |
+
### 10per
|
128 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task 10per --train_dataset finetuning/10per/train.txt --test_dataset finetuning/10per/test.txt --train_label finetuning/10per/train_label.txt --test_label finetuning/10per/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000/bert_trained.seq_encoder.model.ep731 --epochs 51
|
129 |
+
|
130 |
+
### IS
|
131 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task IS --train_dataset finetuning/IS/train.txt --test_dataset finetuning/FS/train.txt --train_label finetuning/IS/train_label.txt --test_label finetuning/FS/train_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000/bert_trained.seq_encoder.model.ep731 --epochs 51
|
132 |
+
|
133 |
+
### FS
|
134 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task FS --train_dataset finetuning/FS/train.txt --test_dataset finetuning/IS/train.txt --train_label finetuning/FS/train_label.txt --test_label finetuning/IS/train_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000/bert_trained.seq_encoder.model.ep731 --epochs 51
|
135 |
+
|
136 |
+
### correctness
|
137 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task correctness --train_dataset finetuning/correctness/train.txt --test_dataset finetuning/correctness/test.txt --train_label finetuning/correctness/train_label.txt --test_label finetuning/correctness/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279 --epochs 51
|
138 |
+
|
139 |
+
### SL
|
140 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task SL --train_dataset finetuning/SL/train.txt --test_dataset finetuning/SL/test.txt --train_label finetuning/SL/train_label.txt --test_label finetuning/SL/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279 --epochs 51
|
141 |
+
|
142 |
+
### effectiveness
|
143 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task effectiveness --train_dataset finetuning/effectiveness/train.txt --test_dataset finetuning/effectiveness/test.txt --train_label finetuning/effectiveness/train_label.txt --test_label finetuning/effectiveness/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279 --epochs 51
|
144 |
+
|
145 |
+
|
146 |
+
## ratio_proportion_change4 : Using Percents and Percent Change
|
147 |
+
> Selected Pretrained model: **ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287**
|
148 |
+
### 10per
|
149 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -finetune_task 10per --train_dataset finetuning/10per/train.txt --test_dataset finetuning/10per/test.txt --train_label finetuning/10per/train_label.txt --test_label finetuning/10per/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287 --epochs 51
|
150 |
+
|
151 |
+
### IS
|
152 |
+
|
153 |
+
### FS
|
154 |
+
|
155 |
+
### correctness
|
156 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -finetune_task correctness --train_dataset finetuning/correctness/train.txt --test_dataset finetuning/correctness/test.txt --train_label finetuning/correctness/train_label.txt --test_label finetuning/correctness/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287 --epochs 51
|
157 |
+
|
158 |
+
### SL
|
159 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -finetune_task SL --train_dataset finetuning/SL/train.txt --test_dataset finetuning/SL/test.txt --train_label finetuning/SL/train_label.txt --test_label finetuning/SL/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287 --epochs 51
|
160 |
+
|
161 |
+
### effectiveness
|
162 |
+
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -finetune_task effectiveness --train_dataset finetuning/effectiveness/train.txt --test_dataset finetuning/effectiveness/test.txt --train_label finetuning/effectiveness/train_label.txt --test_label finetuning/effectiveness/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287 --epochs 51
|
163 |
+
|
164 |
+
|
165 |
+
## scale_drawings_3 : Calculating Measurements Using a Scale
|
166 |
+
> Selected Pretrained model: **scale_drawings_3/output/bert_trained.seq_encoder.model.ep252**
|
167 |
+
### 10per
|
168 |
+
> clear;python3 src/main.py -workspace_name scale_drawings_3 -finetune_task 10per --train_dataset finetuning/10per/train.txt --test_dataset finetuning/10per/test.txt --train_label finetuning/10per/train_label.txt --test_label finetuning/10per/test_label.txt --pretrained_bert_checkpoint scale_drawings_3/output/bert_trained.seq_encoder.model.ep252 --epochs 51
|
169 |
+
|
170 |
+
### IS
|
171 |
+
|
172 |
+
### FS
|
173 |
+
|
174 |
+
### correctness
|
175 |
+
> clear;python3 src/main.py -workspace_name scale_drawings_3 -finetune_task correctness --train_dataset finetuning/correctness/train.txt --test_dataset finetuning/correctness/test.txt --train_label finetuning/correctness/train_label.txt --test_label finetuning/correctness/test_label.txt --pretrained_bert_checkpoint scale_drawings_3/output/bert_trained.seq_encoder.model.ep252 --epochs 51
|
176 |
+
|
177 |
+
### SL
|
178 |
+
> clear;python3 src/main.py -workspace_name scale_drawings_3 -finetune_task SL --train_dataset finetuning/SL/train.txt --test_dataset finetuning/SL/test.txt --train_label finetuning/SL/train_label.txt --test_label finetuning/SL/test_label.txt --pretrained_bert_checkpoint scale_drawings_3/output/bert_trained.seq_encoder.model.ep252 --epochs 51
|
179 |
+
|
180 |
+
### effectiveness
|
181 |
+
|
182 |
+
## sales_tax_discounts_two_rates : Solving Problems with Both Sales Tax and Discounts
|
183 |
+
> Selected Pretrained model: **sales_tax_discounts_two_rates/output/bert_trained.seq_encoder.model.ep255**
|
184 |
+
|
185 |
+
### 10per
|
186 |
+
> clear;python3 src/main.py -workspace_name sales_tax_discounts_two_rates -finetune_task 10per --train_dataset finetuning/10per/train.txt --test_dataset finetuning/10per/test.txt --train_label finetuning/10per/train_label.txt --test_label finetuning/10per/test_label.txt --pretrained_bert_checkpoint sales_tax_discounts_two_rates/output/bert_trained.seq_encoder.model.ep255 --epochs 51
|
187 |
+
|
188 |
+
### IS
|
189 |
+
|
190 |
+
### FS
|
191 |
+
|
192 |
+
### correctness
|
193 |
+
> clear;python3 src/main.py -workspace_name sales_tax_discounts_two_rates -finetune_task correctness --train_dataset finetuning/correctness/train.txt --test_dataset finetuning/correctness/test.txt --train_label finetuning/correctness/train_label.txt --test_label finetuning/correctness/test_label.txt --pretrained_bert_checkpoint sales_tax_discounts_two_rates/output/bert_trained.seq_encoder.model.ep255 --epochs 51
|
194 |
+
|
195 |
+
### SL
|
196 |
+
|
197 |
+
### effectiveness
|
new_fine_tuning/__pycache__/metrics.cpython-312.pyc
ADDED
Binary file (9.16 kB). View file
|
|
new_fine_tuning/__pycache__/recalibration.cpython-312.pyc
ADDED
Binary file (5.51 kB). View file
|
|
new_fine_tuning/__pycache__/visualization.cpython-312.pyc
ADDED
Binary file (5.28 kB). View file
|
|
new_hint_fine_tuned.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.utils.data import DataLoader, random_split, TensorDataset
|
6 |
+
from src.dataset import TokenizerDataset
|
7 |
+
from src.bert import BERT
|
8 |
+
from src.pretrainer import BERTFineTuneTrainer1
|
9 |
+
from src.vocab import Vocab
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
def preprocess_labels(label_csv_path):
|
13 |
+
try:
|
14 |
+
labels_df = pd.read_csv(label_csv_path)
|
15 |
+
labels = labels_df['last_hint_class'].values.astype(int)
|
16 |
+
return torch.tensor(labels, dtype=torch.long)
|
17 |
+
except Exception as e:
|
18 |
+
print(f"Error reading dataset file: {e}")
|
19 |
+
return None
|
20 |
+
|
21 |
+
def preprocess_data(data_path, vocab, max_length=128):
|
22 |
+
try:
|
23 |
+
with open(data_path, 'r') as f:
|
24 |
+
sequences = f.readlines()
|
25 |
+
except Exception as e:
|
26 |
+
print(f"Error reading data file: {e}")
|
27 |
+
return None, None
|
28 |
+
|
29 |
+
tokenized_sequences = []
|
30 |
+
for sequence in sequences:
|
31 |
+
sequence = sequence.strip()
|
32 |
+
if sequence:
|
33 |
+
encoded = vocab.to_seq(sequence, seq_len=max_length)
|
34 |
+
encoded = encoded[:max_length] + [vocab.vocab.get('[PAD]', 0)] * (max_length - len(encoded))
|
35 |
+
segment_label = [0] * max_length
|
36 |
+
|
37 |
+
tokenized_sequences.append({
|
38 |
+
'input_ids': torch.tensor(encoded),
|
39 |
+
'segment_label': torch.tensor(segment_label)
|
40 |
+
})
|
41 |
+
|
42 |
+
input_ids = torch.cat([t['input_ids'].unsqueeze(0) for t in tokenized_sequences], dim=0)
|
43 |
+
segment_labels = torch.cat([t['segment_label'].unsqueeze(0) for t in tokenized_sequences], dim=0)
|
44 |
+
|
45 |
+
print(f"Input IDs shape: {input_ids.shape}")
|
46 |
+
print(f"Segment labels shape: {segment_labels.shape}")
|
47 |
+
|
48 |
+
return input_ids, segment_labels
|
49 |
+
|
50 |
+
def custom_collate_fn(batch):
|
51 |
+
inputs = [item['input_ids'].unsqueeze(0) for item in batch]
|
52 |
+
labels = [item['label'].unsqueeze(0) for item in batch]
|
53 |
+
segment_labels = [item['segment_label'].unsqueeze(0) for item in batch]
|
54 |
+
|
55 |
+
inputs = torch.cat(inputs, dim=0)
|
56 |
+
labels = torch.cat(labels, dim=0)
|
57 |
+
segment_labels = torch.cat(segment_labels, dim=0)
|
58 |
+
|
59 |
+
return {
|
60 |
+
'input': inputs,
|
61 |
+
'label': labels,
|
62 |
+
'segment_label': segment_labels
|
63 |
+
}
|
64 |
+
|
65 |
+
def main(opt):
|
66 |
+
# Set device to GPU if available, otherwise use CPU
|
67 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
68 |
+
|
69 |
+
# Load vocabulary
|
70 |
+
vocab = Vocab(opt.vocab_file)
|
71 |
+
vocab.load_vocab()
|
72 |
+
|
73 |
+
# Preprocess data and labels
|
74 |
+
input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=50) # Using sequence length 50
|
75 |
+
labels = preprocess_labels(opt.dataset)
|
76 |
+
|
77 |
+
if input_ids is None or segment_labels is None or labels is None:
|
78 |
+
print("Error in preprocessing data. Exiting.")
|
79 |
+
return
|
80 |
+
|
81 |
+
# Create TensorDataset and split into train and validation sets
|
82 |
+
dataset = TensorDataset(input_ids, segment_labels, labels)
|
83 |
+
val_size = len(dataset) - int(0.8 * len(dataset))
|
84 |
+
val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])
|
85 |
+
|
86 |
+
# Create DataLoaders for training and validation
|
87 |
+
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
|
88 |
+
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)
|
89 |
+
|
90 |
+
# Initialize custom BERT model and move it to the device
|
91 |
+
custom_model = CustomBERTModel(
|
92 |
+
vocab_size=len(vocab.vocab),
|
93 |
+
output_dim=2,
|
94 |
+
pre_trained_model_path=opt.pre_trained_model_path
|
95 |
+
).to(device)
|
96 |
+
|
97 |
+
# Initialize the fine-tuning trainer
|
98 |
+
trainer = BERTFineTuneTrainer1(
|
99 |
+
bert=custom_model,
|
100 |
+
vocab_size=len(vocab.vocab),
|
101 |
+
train_dataloader=train_dataloader,
|
102 |
+
test_dataloader=val_dataloader,
|
103 |
+
lr=1e-5, # Using learning rate 10^-5 as specified
|
104 |
+
num_labels=2,
|
105 |
+
with_cuda=torch.cuda.is_available(),
|
106 |
+
log_freq=10,
|
107 |
+
workspace_name=opt.output_dir,
|
108 |
+
log_folder_path=opt.log_folder_path
|
109 |
+
)
|
110 |
+
|
111 |
+
# Train the model
|
112 |
+
trainer.train(epoch=20)
|
113 |
+
|
114 |
+
# Save the model
|
115 |
+
os.makedirs(opt.output_dir, exist_ok=True)
|
116 |
+
output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_3.pth')
|
117 |
+
torch.save(custom_model, output_model_file)
|
118 |
+
print(f'Model saved to {output_model_file}')
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
|
122 |
+
parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the dataset file.')
|
123 |
+
parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
|
124 |
+
parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
|
125 |
+
parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
|
126 |
+
parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
|
127 |
+
parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct', help='Path to the folder for saving logs.')
|
128 |
+
|
129 |
+
|
130 |
+
opt = parser.parse_args()
|
131 |
+
main(opt)
|
new_test_saved_finetuned_model.py
ADDED
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.optim import Adam
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
import pickle
|
8 |
+
print("here1",os.getcwd())
|
9 |
+
from src.dataset import TokenizerDataset, TokenizerDatasetForCalibration
|
10 |
+
from src.vocab import Vocab
|
11 |
+
print("here3",os.getcwd())
|
12 |
+
from src.bert import BERT
|
13 |
+
from src.seq_model import BERTSM
|
14 |
+
from src.classifier_model import BERTForClassification, BERTForClassificationWithFeats
|
15 |
+
# from src.new_finetuning.optim_schedule import ScheduledOptim
|
16 |
+
import metrics, recalibration, visualization
|
17 |
+
from recalibration import ModelWithTemperature
|
18 |
+
import tqdm
|
19 |
+
import sys
|
20 |
+
import time
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_curve, roc_auc_score
|
24 |
+
import matplotlib.pyplot as plt
|
25 |
+
import seaborn as sns
|
26 |
+
import pandas as pd
|
27 |
+
from collections import defaultdict
|
28 |
+
print("here3",os.getcwd())
|
29 |
+
class BERTFineTuneTrainer:
|
30 |
+
|
31 |
+
def __init__(self, bertFinetunedClassifierwithFeats: BERT, #BERTForClassificationWithFeats
|
32 |
+
vocab_size: int, test_dataloader: DataLoader = None,
|
33 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
34 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
35 |
+
num_labels=2, log_folder_path: str = None):
|
36 |
+
"""
|
37 |
+
:param bert: BERT model which you want to train
|
38 |
+
:param vocab_size: total word vocab size
|
39 |
+
:param test_dataloader: test dataset data loader [can be None]
|
40 |
+
:param lr: learning rate of optimizer
|
41 |
+
:param betas: Adam optimizer betas
|
42 |
+
:param weight_decay: Adam optimizer weight decay param
|
43 |
+
:param with_cuda: traning with cuda
|
44 |
+
:param log_freq: logging frequency of the batch iteration
|
45 |
+
"""
|
46 |
+
|
47 |
+
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
48 |
+
# cuda_condition = torch.cuda.is_available() and with_cuda
|
49 |
+
# self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
50 |
+
self.device = torch.device("cpu") #torch.device("cuda:0" if cuda_condition else "cpu")
|
51 |
+
# print(cuda_condition, " Device used = ", self.device)
|
52 |
+
print(" Device used = ", self.device)
|
53 |
+
|
54 |
+
# available_gpus = list(range(torch.cuda.device_count()))
|
55 |
+
|
56 |
+
# This BERT model will be saved every epoch
|
57 |
+
self.model = bertFinetunedClassifierwithFeats.to("cpu")
|
58 |
+
print(self.model.parameters())
|
59 |
+
for param in self.model.parameters():
|
60 |
+
param.requires_grad = False
|
61 |
+
# Initialize the BERT Language Model, with BERT model
|
62 |
+
# self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
63 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8).to(self.device)
|
64 |
+
# self.model = bertFinetunedClassifierwithFeats
|
65 |
+
# print(self.model.bert.parameters())
|
66 |
+
# for param in self.model.bert.parameters():
|
67 |
+
# param.requires_grad = False
|
68 |
+
# BERTForClassificationWithFeats(self.bert, num_labels, 18).to(self.device)
|
69 |
+
|
70 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 1).to(self.device)
|
71 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
72 |
+
# if with_cuda and torch.cuda.device_count() > 1:
|
73 |
+
# print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
74 |
+
# self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
75 |
+
|
76 |
+
# Setting the train, validation and test data loader
|
77 |
+
# self.train_data = train_dataloader
|
78 |
+
# self.val_data = val_dataloader
|
79 |
+
self.test_data = test_dataloader
|
80 |
+
|
81 |
+
# self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
82 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
83 |
+
# self.optim_schedule = ScheduledOptim(self.optim, self.model.bert.hidden, n_warmup_steps=warmup_steps)
|
84 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
85 |
+
self.criterion = nn.CrossEntropyLoss()
|
86 |
+
|
87 |
+
# if num_labels == 1:
|
88 |
+
# self.criterion = nn.MSELoss()
|
89 |
+
# elif num_labels == 2:
|
90 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
91 |
+
# # self.criterion = nn.CrossEntropyLoss()
|
92 |
+
# elif num_labels > 2:
|
93 |
+
# self.criterion = nn.CrossEntropyLoss()
|
94 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
95 |
+
|
96 |
+
|
97 |
+
self.log_freq = log_freq
|
98 |
+
self.log_folder_path = log_folder_path
|
99 |
+
# self.workspace_name = workspace_name
|
100 |
+
# self.finetune_task = finetune_task
|
101 |
+
# self.save_model = False
|
102 |
+
# self.avg_loss = 10000
|
103 |
+
self.start_time = time.time()
|
104 |
+
# self.probability_list = []
|
105 |
+
for fi in ['test']: #'val',
|
106 |
+
f = open(self.log_folder_path+f"/log_{fi}_finetuned.txt", 'w')
|
107 |
+
f.close()
|
108 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
109 |
+
|
110 |
+
# def train(self, epoch):
|
111 |
+
# self.iteration(epoch, self.train_data)
|
112 |
+
|
113 |
+
# def val(self, epoch):
|
114 |
+
# self.iteration(epoch, self.val_data, phase="val")
|
115 |
+
|
116 |
+
def test(self, epoch):
|
117 |
+
# if epoch == 0:
|
118 |
+
# self.avg_loss = 10000
|
119 |
+
self.iteration(epoch, self.test_data, phase="test")
|
120 |
+
|
121 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
122 |
+
"""
|
123 |
+
loop over the data_loader for training or testing
|
124 |
+
if on train status, backward operation is activated
|
125 |
+
and also auto save the model every peoch
|
126 |
+
|
127 |
+
:param epoch: current epoch index
|
128 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
129 |
+
:param train: boolean value of is train or test
|
130 |
+
:return: None
|
131 |
+
"""
|
132 |
+
|
133 |
+
# Setting the tqdm progress bar
|
134 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
135 |
+
desc="EP_%s:%d" % (phase, epoch),
|
136 |
+
total=len(data_loader),
|
137 |
+
bar_format="{l_bar}{r_bar}")
|
138 |
+
|
139 |
+
avg_loss = 0.0
|
140 |
+
total_correct = 0
|
141 |
+
total_element = 0
|
142 |
+
plabels = []
|
143 |
+
tlabels = []
|
144 |
+
probabs = []
|
145 |
+
|
146 |
+
if phase == "train":
|
147 |
+
self.model.train()
|
148 |
+
else:
|
149 |
+
self.model.eval()
|
150 |
+
# self.probability_list = []
|
151 |
+
|
152 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned.txt", 'a') as f:
|
153 |
+
sys.stdout = f
|
154 |
+
for i, data in data_iter:
|
155 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
156 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
157 |
+
if phase == "train":
|
158 |
+
logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
159 |
+
else:
|
160 |
+
with torch.no_grad():
|
161 |
+
logits = self.model.forward(data["input"].cpu(), data["segment_label"].cpu(), data["feat"].cpu())
|
162 |
+
|
163 |
+
logits = logits.cpu()
|
164 |
+
loss = self.criterion(logits, data["label"])
|
165 |
+
# if torch.cuda.device_count() > 1:
|
166 |
+
# loss = loss.mean()
|
167 |
+
|
168 |
+
# 3. backward and optimization only in train
|
169 |
+
# if phase == "train":
|
170 |
+
# self.optim_schedule.zero_grad()
|
171 |
+
# loss.backward()
|
172 |
+
# self.optim_schedule.step_and_update_lr()
|
173 |
+
|
174 |
+
# prediction accuracy
|
175 |
+
probs = nn.Softmax(dim=-1)(logits) # Probabilities
|
176 |
+
probabs.extend(probs.detach().cpu().numpy().tolist())
|
177 |
+
predicted_labels = torch.argmax(probs, dim=-1) #correct
|
178 |
+
# self.probability_list.append(probs)
|
179 |
+
# true_labels = torch.argmax(data["label"], dim=-1)
|
180 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
181 |
+
tlabels.extend(data['label'].cpu().numpy())
|
182 |
+
|
183 |
+
# Compare predicted labels to true labels and calculate accuracy
|
184 |
+
correct = (data['label'] == predicted_labels).sum().item()
|
185 |
+
|
186 |
+
avg_loss += loss.item()
|
187 |
+
total_correct += correct
|
188 |
+
# total_element += true_labels.nelement()
|
189 |
+
total_element += data["label"].nelement()
|
190 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
191 |
+
|
192 |
+
post_fix = {
|
193 |
+
"epoch": epoch,
|
194 |
+
"iter": i,
|
195 |
+
"avg_loss": avg_loss / (i + 1),
|
196 |
+
"avg_acc": total_correct / total_element * 100 if total_element != 0 else 0,
|
197 |
+
"loss": loss.item()
|
198 |
+
}
|
199 |
+
if i % self.log_freq == 0:
|
200 |
+
data_iter.write(str(post_fix))
|
201 |
+
|
202 |
+
precisions = precision_score(tlabels, plabels, average="weighted", zero_division=0)
|
203 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
204 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
205 |
+
cmatrix = confusion_matrix(tlabels, plabels)
|
206 |
+
end_time = time.time()
|
207 |
+
auc_score = roc_auc_score(tlabels, plabels)
|
208 |
+
final_msg = {
|
209 |
+
"epoch": f"EP{epoch}_{phase}",
|
210 |
+
"avg_loss": avg_loss / len(data_iter),
|
211 |
+
"total_acc": total_correct * 100.0 / total_element,
|
212 |
+
"precisions": precisions,
|
213 |
+
"recalls": recalls,
|
214 |
+
"f1_scores": f1_scores,
|
215 |
+
# "confusion_matrix": f"{cmatrix}",
|
216 |
+
# "true_labels": f"{tlabels}",
|
217 |
+
# "predicted_labels": f"{plabels}",
|
218 |
+
"time_taken_from_start": end_time - self.start_time,
|
219 |
+
"auc_score":auc_score
|
220 |
+
}
|
221 |
+
with open("result.txt", 'w') as file:
|
222 |
+
for key, value in final_msg.items():
|
223 |
+
file.write(f"{key}: {value}\n")
|
224 |
+
print(final_msg)
|
225 |
+
fpr, tpr, thresholds = roc_curve(tlabels, plabels)
|
226 |
+
with open("roc_data.pkl", "wb") as f:
|
227 |
+
pickle.dump((fpr, tpr, thresholds), f)
|
228 |
+
print(final_msg)
|
229 |
+
f.close()
|
230 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
|
231 |
+
sys.stdout = f1
|
232 |
+
final_msg = {
|
233 |
+
"epoch": f"EP{epoch}_{phase}",
|
234 |
+
"confusion_matrix": f"{cmatrix}",
|
235 |
+
"true_labels": f"{tlabels if epoch == 0 else ''}",
|
236 |
+
"predicted_labels": f"{plabels}",
|
237 |
+
"probabilities": f"{probabs}",
|
238 |
+
"time_taken_from_start": end_time - self.start_time
|
239 |
+
}
|
240 |
+
print(final_msg)
|
241 |
+
f1.close()
|
242 |
+
sys.stdout = sys.__stdout__
|
243 |
+
sys.stdout = sys.__stdout__
|
244 |
+
|
245 |
+
|
246 |
+
|
247 |
+
class BERTFineTuneCalibratedTrainer:
|
248 |
+
|
249 |
+
def __init__(self, bertFinetunedClassifierwithFeats: BERT, #BERTForClassificationWithFeats
|
250 |
+
vocab_size: int, test_dataloader: DataLoader = None,
|
251 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
252 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
253 |
+
num_labels=2, log_folder_path: str = None):
|
254 |
+
"""
|
255 |
+
:param bert: BERT model which you want to train
|
256 |
+
:param vocab_size: total word vocab size
|
257 |
+
:param test_dataloader: test dataset data loader [can be None]
|
258 |
+
:param lr: learning rate of optimizer
|
259 |
+
:param betas: Adam optimizer betas
|
260 |
+
:param weight_decay: Adam optimizer weight decay param
|
261 |
+
:param with_cuda: traning with cuda
|
262 |
+
:param log_freq: logging frequency of the batch iteration
|
263 |
+
"""
|
264 |
+
|
265 |
+
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
266 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
267 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
268 |
+
print(cuda_condition, " Device used = ", self.device)
|
269 |
+
|
270 |
+
# available_gpus = list(range(torch.cuda.device_count()))
|
271 |
+
|
272 |
+
# This BERT model will be saved every epoch
|
273 |
+
self.model = bertFinetunedClassifierwithFeats
|
274 |
+
print(self.model.parameters())
|
275 |
+
for param in self.model.parameters():
|
276 |
+
param.requires_grad = False
|
277 |
+
# Initialize the BERT Language Model, with BERT model
|
278 |
+
# self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
279 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8).to(self.device)
|
280 |
+
# self.model = bertFinetunedClassifierwithFeats
|
281 |
+
# print(self.model.bert.parameters())
|
282 |
+
# for param in self.model.bert.parameters():
|
283 |
+
# param.requires_grad = False
|
284 |
+
# BERTForClassificationWithFeats(self.bert, num_labels, 18).to(self.device)
|
285 |
+
|
286 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 1).to(self.device)
|
287 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
288 |
+
# if with_cuda and torch.cuda.device_count() > 1:
|
289 |
+
# print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
290 |
+
# self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
291 |
+
|
292 |
+
# Setting the train, validation and test data loader
|
293 |
+
# self.train_data = train_dataloader
|
294 |
+
# self.val_data = val_dataloader
|
295 |
+
self.test_data = test_dataloader
|
296 |
+
|
297 |
+
# self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
298 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
299 |
+
# self.optim_schedule = ScheduledOptim(self.optim, self.model.bert.hidden, n_warmup_steps=warmup_steps)
|
300 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
301 |
+
self.criterion = nn.CrossEntropyLoss()
|
302 |
+
|
303 |
+
# if num_labels == 1:
|
304 |
+
# self.criterion = nn.MSELoss()
|
305 |
+
# elif num_labels == 2:
|
306 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
307 |
+
# # self.criterion = nn.CrossEntropyLoss()
|
308 |
+
# elif num_labels > 2:
|
309 |
+
# self.criterion = nn.CrossEntropyLoss()
|
310 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
311 |
+
|
312 |
+
|
313 |
+
self.log_freq = log_freq
|
314 |
+
self.log_folder_path = log_folder_path
|
315 |
+
# self.workspace_name = workspace_name
|
316 |
+
# self.finetune_task = finetune_task
|
317 |
+
# self.save_model = False
|
318 |
+
# self.avg_loss = 10000
|
319 |
+
self.start_time = time.time()
|
320 |
+
# self.probability_list = []
|
321 |
+
for fi in ['test']: #'val',
|
322 |
+
f = open(self.log_folder_path+f"/log_{fi}_finetuned.txt", 'w')
|
323 |
+
f.close()
|
324 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
325 |
+
|
326 |
+
# def train(self, epoch):
|
327 |
+
# self.iteration(epoch, self.train_data)
|
328 |
+
|
329 |
+
# def val(self, epoch):
|
330 |
+
# self.iteration(epoch, self.val_data, phase="val")
|
331 |
+
|
332 |
+
def test(self, epoch):
|
333 |
+
# if epoch == 0:
|
334 |
+
# self.avg_loss = 10000
|
335 |
+
self.iteration(epoch, self.test_data, phase="test")
|
336 |
+
|
337 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
338 |
+
"""
|
339 |
+
loop over the data_loader for training or testing
|
340 |
+
if on train status, backward operation is activated
|
341 |
+
and also auto save the model every peoch
|
342 |
+
|
343 |
+
:param epoch: current epoch index
|
344 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
345 |
+
:param train: boolean value of is train or test
|
346 |
+
:return: None
|
347 |
+
"""
|
348 |
+
|
349 |
+
# Setting the tqdm progress bar
|
350 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
351 |
+
desc="EP_%s:%d" % (phase, epoch),
|
352 |
+
total=len(data_loader),
|
353 |
+
bar_format="{l_bar}{r_bar}")
|
354 |
+
|
355 |
+
avg_loss = 0.0
|
356 |
+
total_correct = 0
|
357 |
+
total_element = 0
|
358 |
+
plabels = []
|
359 |
+
tlabels = []
|
360 |
+
probabs = []
|
361 |
+
|
362 |
+
if phase == "train":
|
363 |
+
self.model.train()
|
364 |
+
else:
|
365 |
+
self.model.eval()
|
366 |
+
# self.probability_list = []
|
367 |
+
|
368 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned.txt", 'a') as f:
|
369 |
+
sys.stdout = f
|
370 |
+
for i, data in data_iter:
|
371 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
372 |
+
# print(data_pair[0])
|
373 |
+
data = {key: value.to(self.device) for key, value in data[0].items()}
|
374 |
+
# print(f"data : {data}")
|
375 |
+
# data = {key: value.to(self.device) for key, value in data.items()}
|
376 |
+
|
377 |
+
# if phase == "train":
|
378 |
+
# logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
379 |
+
# else:
|
380 |
+
with torch.no_grad():
|
381 |
+
# logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
382 |
+
logits = self.model.forward(data)
|
383 |
+
|
384 |
+
loss = self.criterion(logits, data["label"])
|
385 |
+
if torch.cuda.device_count() > 1:
|
386 |
+
loss = loss.mean()
|
387 |
+
|
388 |
+
# 3. backward and optimization only in train
|
389 |
+
# if phase == "train":
|
390 |
+
# self.optim_schedule.zero_grad()
|
391 |
+
# loss.backward()
|
392 |
+
# self.optim_schedule.step_and_update_lr()
|
393 |
+
|
394 |
+
# prediction accuracy
|
395 |
+
probs = nn.Softmax(dim=-1)(logits) # Probabilities
|
396 |
+
probabs.extend(probs.detach().cpu().numpy().tolist())
|
397 |
+
predicted_labels = torch.argmax(probs, dim=-1) #correct
|
398 |
+
# self.probability_list.append(probs)
|
399 |
+
# true_labels = torch.argmax(data["label"], dim=-1)
|
400 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
401 |
+
tlabels.extend(data['label'].cpu().numpy())
|
402 |
+
positive_class_probs = [prob[1] for prob in probabs]
|
403 |
+
|
404 |
+
# Compare predicted labels to true labels and calculate accuracy
|
405 |
+
correct = (data['label'] == predicted_labels).sum().item()
|
406 |
+
|
407 |
+
avg_loss += loss.item()
|
408 |
+
total_correct += correct
|
409 |
+
# total_element += true_labels.nelement()
|
410 |
+
total_element += data["label"].nelement()
|
411 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
412 |
+
|
413 |
+
post_fix = {
|
414 |
+
"epoch": epoch,
|
415 |
+
"iter": i,
|
416 |
+
"avg_loss": avg_loss / (i + 1),
|
417 |
+
"avg_acc": total_correct / total_element * 100 if total_element != 0 else 0,
|
418 |
+
"loss": loss.item()
|
419 |
+
}
|
420 |
+
if i % self.log_freq == 0:
|
421 |
+
data_iter.write(str(post_fix))
|
422 |
+
|
423 |
+
precisions = precision_score(tlabels, plabels, average="weighted", zero_division=0)
|
424 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
425 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
426 |
+
cmatrix = confusion_matrix(tlabels, plabels)
|
427 |
+
auc_score = roc_auc_score(tlabels, positive_class_probs)
|
428 |
+
end_time = time.time()
|
429 |
+
final_msg = {
|
430 |
+
"epoch": f"EP{epoch}_{phase}",
|
431 |
+
"avg_loss": avg_loss / len(data_iter),
|
432 |
+
"total_acc": total_correct * 100.0 / total_element,
|
433 |
+
"precisions": precisions,
|
434 |
+
"recalls": recalls,
|
435 |
+
"f1_scores": f1_scores,
|
436 |
+
"auc_score":auc_score,
|
437 |
+
# "confusion_matrix": f"{cmatrix}",
|
438 |
+
# "true_labels": f"{tlabels}",
|
439 |
+
# "predicted_labels": f"{plabels}",
|
440 |
+
"time_taken_from_start": end_time - self.start_time
|
441 |
+
}
|
442 |
+
with open("result.txt", 'w') as file:
|
443 |
+
for key, value in final_msg.items():
|
444 |
+
file.write(f"{key}: {value}\n")
|
445 |
+
|
446 |
+
print(final_msg)
|
447 |
+
fpr, tpr, thresholds = roc_curve(tlabels, positive_class_probs)
|
448 |
+
f.close()
|
449 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
|
450 |
+
sys.stdout = f1
|
451 |
+
final_msg = {
|
452 |
+
"epoch": f"EP{epoch}_{phase}",
|
453 |
+
"confusion_matrix": f"{cmatrix}",
|
454 |
+
"true_labels": f"{tlabels if epoch == 0 else ''}",
|
455 |
+
"predicted_labels": f"{plabels}",
|
456 |
+
"probabilities": f"{probabs}",
|
457 |
+
"time_taken_from_start": end_time - self.start_time
|
458 |
+
}
|
459 |
+
print(final_msg)
|
460 |
+
f1.close()
|
461 |
+
sys.stdout = sys.__stdout__
|
462 |
+
sys.stdout = sys.__stdout__
|
463 |
+
|
464 |
+
|
465 |
+
|
466 |
+
def train():
|
467 |
+
parser = argparse.ArgumentParser()
|
468 |
+
|
469 |
+
parser.add_argument('-workspace_name', type=str, default=None)
|
470 |
+
parser.add_argument('-code', type=str, default=None, help="folder for pretraining outputs and logs")
|
471 |
+
parser.add_argument('-finetune_task', type=str, default=None, help="folder inside finetuning")
|
472 |
+
parser.add_argument("-attention", type=bool, default=False, help="analyse attention scores")
|
473 |
+
parser.add_argument("-diff_test_folder", type=bool, default=False, help="use for different test folder")
|
474 |
+
parser.add_argument("-embeddings", type=bool, default=False, help="get and analyse embeddings")
|
475 |
+
parser.add_argument('-embeddings_file_name', type=str, default=None, help="file name of embeddings")
|
476 |
+
parser.add_argument("-pretrain", type=bool, default=False, help="pretraining: true, or false")
|
477 |
+
# parser.add_argument('-opts', nargs='+', type=str, default=None, help='List of optional steps')
|
478 |
+
parser.add_argument("-max_mask", type=int, default=0.15, help="% of input tokens selected for masking")
|
479 |
+
# parser.add_argument("-p", "--pretrain_dataset", type=str, default="pretraining/pretrain.txt", help="pretraining dataset for bert")
|
480 |
+
# parser.add_argument("-pv", "--pretrain_val_dataset", type=str, default="pretraining/test.txt", help="pretraining validation dataset for bert")
|
481 |
+
# default="finetuning/test.txt",
|
482 |
+
parser.add_argument("-vocab_path", type=str, default="pretraining/vocab.txt", help="built vocab model path with bert-vocab")
|
483 |
+
|
484 |
+
parser.add_argument("-train_dataset_path", type=str, default="train.txt", help="fine tune train dataset for progress classifier")
|
485 |
+
parser.add_argument("-val_dataset_path", type=str, default="val.txt", help="test set for evaluate fine tune train set")
|
486 |
+
parser.add_argument("-test_dataset_path", type=str, default="test.txt", help="test set for evaluate fine tune train set")
|
487 |
+
parser.add_argument("-num_labels", type=int, default=2, help="Number of labels")
|
488 |
+
parser.add_argument("-train_label_path", type=str, default="train_label.txt", help="fine tune train dataset for progress classifier")
|
489 |
+
parser.add_argument("-val_label_path", type=str, default="val_label.txt", help="test set for evaluate fine tune train set")
|
490 |
+
parser.add_argument("-test_label_path", type=str, default="test_label.txt", help="test set for evaluate fine tune train set")
|
491 |
+
##### change Checkpoint for finetuning
|
492 |
+
parser.add_argument("-pretrained_bert_checkpoint", type=str, default=None, help="checkpoint of saved pretrained bert model")
|
493 |
+
parser.add_argument("-finetuned_bert_classifier_checkpoint", type=str, default=None, help="checkpoint of saved finetuned bert model") #."output_feb09/bert_trained.model.ep40"
|
494 |
+
#."output_feb09/bert_trained.model.ep40"
|
495 |
+
parser.add_argument('-check_epoch', type=int, default=None)
|
496 |
+
|
497 |
+
parser.add_argument("-hs", "--hidden", type=int, default=64, help="hidden size of transformer model") #64
|
498 |
+
parser.add_argument("-l", "--layers", type=int, default=4, help="number of layers") #4
|
499 |
+
parser.add_argument("-a", "--attn_heads", type=int, default=4, help="number of attention heads") #8
|
500 |
+
parser.add_argument("-s", "--seq_len", type=int, default=5, help="maximum sequence length")
|
501 |
+
|
502 |
+
parser.add_argument("-b", "--batch_size", type=int, default=500, help="number of batch_size") #64
|
503 |
+
parser.add_argument("-e", "--epochs", type=int, default=1)#1501, help="number of epochs") #501
|
504 |
+
# Use 50 for pretrain, and 10 for fine tune
|
505 |
+
parser.add_argument("-w", "--num_workers", type=int, default=0, help="dataloader worker size")
|
506 |
+
|
507 |
+
# Later run with cuda
|
508 |
+
parser.add_argument("--with_cuda", type=bool, default=False, help="training with CUDA: true, or false")
|
509 |
+
parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
|
510 |
+
# parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
|
511 |
+
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
|
512 |
+
# parser.add_argument("--on_memory", type=bool, default=False, help="Loading on memory: true or false")
|
513 |
+
|
514 |
+
parser.add_argument("--dropout", type=float, default=0.1, help="dropout of network")
|
515 |
+
parser.add_argument("--lr", type=float, default=1e-05, help="learning rate of adam") #1e-3
|
516 |
+
parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
|
517 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
|
518 |
+
parser.add_argument("--adam_beta2", type=float, default=0.98, help="adam first beta value") #0.999
|
519 |
+
|
520 |
+
parser.add_argument("-o", "--output_path", type=str, default="bert_trained.seq_encoder.model", help="ex)output/bert.model")
|
521 |
+
# parser.add_argument("-o", "--output_path", type=str, default="output/bert_fine_tuned.model", help="ex)output/bert.model")
|
522 |
+
|
523 |
+
args = parser.parse_args()
|
524 |
+
for k,v in vars(args).items():
|
525 |
+
if 'path' in k:
|
526 |
+
if v:
|
527 |
+
if k == "output_path":
|
528 |
+
if args.code:
|
529 |
+
setattr(args, f"{k}", args.workspace_name+f"/output/{args.code}/"+v)
|
530 |
+
elif args.finetune_task:
|
531 |
+
setattr(args, f"{k}", args.workspace_name+f"/output/{args.finetune_task}/"+v)
|
532 |
+
else:
|
533 |
+
setattr(args, f"{k}", args.workspace_name+"/output/"+v)
|
534 |
+
elif k != "vocab_path":
|
535 |
+
if args.pretrain:
|
536 |
+
setattr(args, f"{k}", args.workspace_name+"/pretraining/"+v)
|
537 |
+
else:
|
538 |
+
if args.code:
|
539 |
+
setattr(args, f"{k}", args.workspace_name+f"/{args.code}/"+v)
|
540 |
+
elif args.finetune_task:
|
541 |
+
if args.diff_test_folder and "test" in k:
|
542 |
+
setattr(args, f"{k}", args.workspace_name+f"/finetuning/"+v)
|
543 |
+
else:
|
544 |
+
setattr(args, f"{k}", args.workspace_name+f"/finetuning/{args.finetune_task}/"+v)
|
545 |
+
else:
|
546 |
+
setattr(args, f"{k}", args.workspace_name+"/finetuning/"+v)
|
547 |
+
else:
|
548 |
+
setattr(args, f"{k}", args.workspace_name+"/"+v)
|
549 |
+
|
550 |
+
print(f"args.{k} : {getattr(args, f'{k}')}")
|
551 |
+
|
552 |
+
print("Loading Vocab", args.vocab_path)
|
553 |
+
vocab_obj = Vocab(args.vocab_path)
|
554 |
+
vocab_obj.load_vocab()
|
555 |
+
print("Vocab Size: ", len(vocab_obj.vocab))
|
556 |
+
|
557 |
+
|
558 |
+
print("Testing using finetuned model......")
|
559 |
+
print("Loading Test Dataset", args.test_dataset_path)
|
560 |
+
test_dataset = TokenizerDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
|
561 |
+
# test_dataset = TokenizerDatasetForCalibration(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
|
562 |
+
|
563 |
+
print("Creating Dataloader...")
|
564 |
+
test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
|
565 |
+
|
566 |
+
print("Load fine-tuned BERT classifier model with feats")
|
567 |
+
# cuda_condition = torch.cuda.is_available() and args.with_cuda
|
568 |
+
device = torch.device("cpu") #torch.device("cuda:0" if cuda_condition else "cpu")
|
569 |
+
finetunedBERTclassifier = torch.load(args.finetuned_bert_classifier_checkpoint, map_location=device)
|
570 |
+
if isinstance(finetunedBERTclassifier, torch.nn.DataParallel):
|
571 |
+
finetunedBERTclassifier = finetunedBERTclassifier.module
|
572 |
+
|
573 |
+
new_log_folder = f"{args.workspace_name}/logs"
|
574 |
+
new_output_folder = f"{args.workspace_name}/output"
|
575 |
+
if args.finetune_task: # is sent almost all the time
|
576 |
+
new_log_folder = f"{args.workspace_name}/logs/{args.finetune_task}"
|
577 |
+
new_output_folder = f"{args.workspace_name}/output/{args.finetune_task}"
|
578 |
+
|
579 |
+
if not os.path.exists(new_log_folder):
|
580 |
+
os.makedirs(new_log_folder)
|
581 |
+
if not os.path.exists(new_output_folder):
|
582 |
+
os.makedirs(new_output_folder)
|
583 |
+
|
584 |
+
print("Creating BERT Fine Tuned Test Trainer")
|
585 |
+
trainer = BERTFineTuneTrainer(finetunedBERTclassifier,
|
586 |
+
len(vocab_obj.vocab), test_dataloader=test_data_loader,
|
587 |
+
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
|
588 |
+
with_cuda=args.with_cuda, cuda_devices = args.cuda_devices, log_freq=args.log_freq,
|
589 |
+
workspace_name = args.workspace_name, num_labels=args.num_labels, log_folder_path=new_log_folder)
|
590 |
+
|
591 |
+
# trainer = BERTFineTuneCalibratedTrainer(finetunedBERTclassifier,
|
592 |
+
# len(vocab_obj.vocab), test_dataloader=test_data_loader,
|
593 |
+
# lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
|
594 |
+
# with_cuda=args.with_cuda, cuda_devices = args.cuda_devices, log_freq=args.log_freq,
|
595 |
+
# workspace_name = args.workspace_name, num_labels=args.num_labels, log_folder_path=new_log_folder)
|
596 |
+
print("Testing fine-tuned model Start....")
|
597 |
+
start_time = time.time()
|
598 |
+
repoch = range(args.check_epoch, args.epochs) if args.check_epoch else range(args.epochs)
|
599 |
+
counter = 0
|
600 |
+
# patience = 10
|
601 |
+
for epoch in repoch:
|
602 |
+
print(f'Test Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
603 |
+
trainer.test(epoch)
|
604 |
+
# pickle.dump(trainer.probability_list, open(f"{args.workspace_name}/output/aaai/change4_mid_prob_{epoch}.pkl","wb"))
|
605 |
+
print(f'Test Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
606 |
+
end_time = time.time()
|
607 |
+
print("Time Taken to fine-tune model = ", end_time - start_time)
|
608 |
+
print(f'Pretraining Ends, Time: {time.strftime("%D %T", time.localtime(end_time))}')
|
609 |
+
|
610 |
+
|
611 |
+
|
612 |
+
if __name__ == "__main__":
|
613 |
+
train()
|
plot.png
CHANGED
prepare_pretraining_input_vocab_file.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ratio_proportion_change3_2223/sch_largest_100-coded/pretraining/vocab.txt
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[PAD]
|
2 |
+
[UNK]
|
3 |
+
[MASK]
|
4 |
+
[CLS]
|
5 |
+
[SEP]
|
6 |
+
DenominatorFactor
|
7 |
+
DenominatorQuantity1-0
|
8 |
+
DenominatorQuantity1-1
|
9 |
+
DenominatorQuantity1-2
|
10 |
+
EquationAnswer
|
11 |
+
FinalAnswer-0
|
12 |
+
FinalAnswer-1
|
13 |
+
FinalAnswer-2
|
14 |
+
FinalAnswerDirection-0
|
15 |
+
FinalAnswerDirection-1
|
16 |
+
FinalAnswerDirection-2
|
17 |
+
FirstRow1:1
|
18 |
+
FirstRow1:2
|
19 |
+
FirstRow2:1
|
20 |
+
FirstRow2:2
|
21 |
+
NumeratorFactor
|
22 |
+
NumeratorQuantity1-0
|
23 |
+
NumeratorQuantity1-1
|
24 |
+
NumeratorQuantity1-2
|
25 |
+
NumeratorQuantity2-0
|
26 |
+
NumeratorQuantity2-1
|
27 |
+
NumeratorQuantity2-2
|
28 |
+
OptionalTask_1
|
29 |
+
OptionalTask_2
|
30 |
+
PercentChange-0
|
31 |
+
PercentChange-1
|
32 |
+
PercentChange-2
|
33 |
+
SecondRow
|
34 |
+
ThirdRow
|
recalibration.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, optim
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
import metrics
|
6 |
+
|
7 |
+
class ModelWithTemperature(nn.Module):
|
8 |
+
"""
|
9 |
+
A thin decorator, which wraps a model with temperature scaling
|
10 |
+
model (nn.Module):
|
11 |
+
A classification neural network
|
12 |
+
NB: Output of the neural network should be the classification logits,
|
13 |
+
NOT the softmax (or log softmax)!
|
14 |
+
"""
|
15 |
+
def __init__(self, model, device="cpu"):
|
16 |
+
super(ModelWithTemperature, self).__init__()
|
17 |
+
self.model = model
|
18 |
+
self.device = torch.device(device)
|
19 |
+
self.temperature = nn.Parameter(torch.ones(1) * 1.5)
|
20 |
+
|
21 |
+
def forward(self, input):
|
22 |
+
logits = self.model(input["input"], input["segment_label"], input["feat"])
|
23 |
+
return self.temperature_scale(logits)
|
24 |
+
|
25 |
+
def temperature_scale(self, logits):
|
26 |
+
"""
|
27 |
+
Perform temperature scaling on logits
|
28 |
+
"""
|
29 |
+
# Expand temperature to match the size of logits
|
30 |
+
temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1)).to(self.device)
|
31 |
+
return logits / temperature
|
32 |
+
|
33 |
+
# This function probably should live outside of this class, but whatever
|
34 |
+
def set_temperature(self, valid_loader):
|
35 |
+
"""
|
36 |
+
Tune the tempearature of the model (using the validation set).
|
37 |
+
We're going to set it to optimize NLL.
|
38 |
+
valid_loader (DataLoader): validation set loader
|
39 |
+
"""
|
40 |
+
#self.cuda()
|
41 |
+
nll_criterion = nn.CrossEntropyLoss()
|
42 |
+
ece_criterion = metrics.ECELoss()
|
43 |
+
|
44 |
+
# First: collect all the logits and labels for the validation set
|
45 |
+
logits_list = []
|
46 |
+
labels_list = []
|
47 |
+
with torch.no_grad():
|
48 |
+
for input, label in valid_loader:
|
49 |
+
# print("Input = ", input["input"])
|
50 |
+
# print("Input = ", input["segment_label"])
|
51 |
+
# print("Input = ", input["feat"])
|
52 |
+
# input = input
|
53 |
+
logits = self.model(input["input"].to(self.device), input["segment_label"].to(self.device), input["feat"].to(self.device))
|
54 |
+
logits_list.append(logits)
|
55 |
+
labels_list.append(label)
|
56 |
+
logits = torch.cat(logits_list).to(self.device)
|
57 |
+
labels = torch.cat(labels_list).to(self.device)
|
58 |
+
|
59 |
+
# Calculate NLL and ECE before temperature scaling
|
60 |
+
before_temperature_nll = nll_criterion(logits, labels).item()
|
61 |
+
before_temperature_ece = ece_criterion.loss(logits.cpu().numpy(),labels.cpu().numpy(),15)
|
62 |
+
#before_temperature_ece = ece_criterion(logits, labels).item()
|
63 |
+
#ece_2 = ece_criterion_2.loss(logits,labels)
|
64 |
+
print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))
|
65 |
+
#print(ece_2)
|
66 |
+
# Next: optimize the temperature w.r.t. NLL
|
67 |
+
optimizer = optim.LBFGS([self.temperature], lr=0.005, max_iter=1000)
|
68 |
+
|
69 |
+
def eval():
|
70 |
+
loss = nll_criterion(self.temperature_scale(logits.to(self.device)), labels.to(self.device))
|
71 |
+
loss.backward()
|
72 |
+
return loss
|
73 |
+
optimizer.step(eval)
|
74 |
+
|
75 |
+
# Calculate NLL and ECE after temperature scaling
|
76 |
+
after_temperature_nll = nll_criterion(self.temperature_scale(logits), labels).item()
|
77 |
+
after_temperature_ece = ece_criterion.loss(self.temperature_scale(logits).detach().cpu().numpy(),labels.cpu().numpy(),15)
|
78 |
+
#after_temperature_ece = ece_criterion(self.temperature_scale(logits), labels).item()
|
79 |
+
print('Optimal temperature: %.3f' % self.temperature.item())
|
80 |
+
print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))
|
81 |
+
|
82 |
+
return self
|
src/__pycache__/attention.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/attention.cpython-312.pyc and b/src/__pycache__/attention.cpython-312.pyc differ
|
|
src/__pycache__/bert.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/bert.cpython-312.pyc and b/src/__pycache__/bert.cpython-312.pyc differ
|
|
src/__pycache__/classifier_model.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/classifier_model.cpython-312.pyc and b/src/__pycache__/classifier_model.cpython-312.pyc differ
|
|
src/__pycache__/dataset.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/dataset.cpython-312.pyc and b/src/__pycache__/dataset.cpython-312.pyc differ
|
|
src/__pycache__/embedding.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/embedding.cpython-312.pyc and b/src/__pycache__/embedding.cpython-312.pyc differ
|
|
src/__pycache__/seq_model.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/seq_model.cpython-312.pyc and b/src/__pycache__/seq_model.cpython-312.pyc differ
|
|
src/__pycache__/transformer.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/transformer.cpython-312.pyc and b/src/__pycache__/transformer.cpython-312.pyc differ
|
|
src/__pycache__/transformer_component.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/transformer_component.cpython-312.pyc and b/src/__pycache__/transformer_component.cpython-312.pyc differ
|
|
src/__pycache__/vocab.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/vocab.cpython-312.pyc and b/src/__pycache__/vocab.cpython-312.pyc differ
|
|
src/attention.py
CHANGED
@@ -3,11 +3,19 @@ import torch.nn.functional as F
|
|
3 |
import torch
|
4 |
|
5 |
import math
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
class Attention(nn.Module):
|
9 |
"""
|
10 |
Compute 'Scaled Dot Product Attention
|
|
|
11 |
"""
|
12 |
|
13 |
def __init__(self):
|
@@ -45,7 +53,10 @@ class MultiHeadedAttention(nn.Module):
|
|
45 |
self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
|
46 |
self.output_linear = nn.Linear(d_model, d_model)
|
47 |
self.attention = Attention()
|
|
|
|
|
48 |
|
|
|
49 |
self.dropout = nn.Dropout(p=dropout)
|
50 |
|
51 |
def forward(self, query, key, value, mask=None):
|
@@ -59,6 +70,14 @@ class MultiHeadedAttention(nn.Module):
|
|
59 |
query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
60 |
for l, x in zip(self.linear_layers, (query, key, value))]
|
61 |
# 2) Apply attention on all the projected vectors in batch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
|
63 |
# torch.Size([64, 8, 100, 100])
|
64 |
# print("Attention", attn.shape)
|
@@ -67,4 +86,5 @@ class MultiHeadedAttention(nn.Module):
|
|
67 |
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
|
68 |
|
69 |
return self.output_linear(x)
|
70 |
-
|
|
|
|
3 |
import torch
|
4 |
|
5 |
import math
|
6 |
+
<<<<<<< HEAD
|
7 |
+
import pickle
|
8 |
+
|
9 |
+
class Attention(nn.Module):
|
10 |
+
"""
|
11 |
+
Compute Scaled Dot Product Attention
|
12 |
+
=======
|
13 |
|
14 |
|
15 |
class Attention(nn.Module):
|
16 |
"""
|
17 |
Compute 'Scaled Dot Product Attention
|
18 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
19 |
"""
|
20 |
|
21 |
def __init__(self):
|
|
|
53 |
self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
|
54 |
self.output_linear = nn.Linear(d_model, d_model)
|
55 |
self.attention = Attention()
|
56 |
+
<<<<<<< HEAD
|
57 |
+
=======
|
58 |
|
59 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
60 |
self.dropout = nn.Dropout(p=dropout)
|
61 |
|
62 |
def forward(self, query, key, value, mask=None):
|
|
|
70 |
query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
71 |
for l, x in zip(self.linear_layers, (query, key, value))]
|
72 |
# 2) Apply attention on all the projected vectors in batch.
|
73 |
+
<<<<<<< HEAD
|
74 |
+
x, p_attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
|
75 |
+
|
76 |
+
# 3) "Concat" using a view and apply a final linear.
|
77 |
+
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
|
78 |
+
|
79 |
+
return self.output_linear(x), p_attn
|
80 |
+
=======
|
81 |
x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
|
82 |
# torch.Size([64, 8, 100, 100])
|
83 |
# print("Attention", attn.shape)
|
|
|
86 |
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
|
87 |
|
88 |
return self.output_linear(x)
|
89 |
+
|
90 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
src/bert.py
CHANGED
@@ -1,7 +1,14 @@
|
|
1 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
from transformer import TransformerBlock
|
4 |
from embedding import BERTEmbedding
|
|
|
5 |
|
6 |
class BERT(nn.Module):
|
7 |
"""
|
@@ -31,10 +38,37 @@ class BERT(nn.Module):
|
|
31 |
# multi-layers transformer blocks, deep network
|
32 |
self.transformer_blocks = nn.ModuleList(
|
33 |
[TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def forward(self, x, segment_info):
|
36 |
# attention masking for padded token
|
37 |
# torch.ByteTensor([batch_size, 1, seq_len, seq_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
|
39 |
# print("bert mask: ", mask)
|
40 |
# embedding the indexed sequence to sequence of vectors
|
@@ -43,5 +77,6 @@ class BERT(nn.Module):
|
|
43 |
# running over multiple transformer blocks
|
44 |
for transformer in self.transformer_blocks:
|
45 |
x = transformer.forward(x, mask)
|
|
|
46 |
|
47 |
return x
|
|
|
1 |
import torch.nn as nn
|
2 |
+
<<<<<<< HEAD
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .transformer import TransformerBlock
|
6 |
+
from .embedding import BERTEmbedding
|
7 |
+
=======
|
8 |
|
9 |
from transformer import TransformerBlock
|
10 |
from embedding import BERTEmbedding
|
11 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
12 |
|
13 |
class BERT(nn.Module):
|
14 |
"""
|
|
|
38 |
# multi-layers transformer blocks, deep network
|
39 |
self.transformer_blocks = nn.ModuleList(
|
40 |
[TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
|
41 |
+
<<<<<<< HEAD
|
42 |
+
# self.attention_values = []
|
43 |
+
=======
|
44 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
45 |
|
46 |
def forward(self, x, segment_info):
|
47 |
# attention masking for padded token
|
48 |
# torch.ByteTensor([batch_size, 1, seq_len, seq_len)
|
49 |
+
<<<<<<< HEAD
|
50 |
+
|
51 |
+
device = x.device
|
52 |
+
|
53 |
+
masked = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1)
|
54 |
+
r,e,c = masked.shape
|
55 |
+
mask = torch.zeros((r, e, c), dtype=torch.bool).to(device=device)
|
56 |
+
|
57 |
+
for i in range(r):
|
58 |
+
mask[i] = masked[i].T*masked[i]
|
59 |
+
mask = mask.unsqueeze(1)
|
60 |
+
# mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
|
61 |
+
|
62 |
+
# print("bert mask: ", mask)
|
63 |
+
# embedding the indexed sequence to sequence of vectors
|
64 |
+
x = self.embedding(x, segment_info)
|
65 |
+
|
66 |
+
# self.attention_values = []
|
67 |
+
# running over multiple transformer blocks
|
68 |
+
for transformer in self.transformer_blocks:
|
69 |
+
x = transformer.forward(x, mask)
|
70 |
+
# self.attention_values.append(transformer.p_attn)
|
71 |
+
=======
|
72 |
mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
|
73 |
# print("bert mask: ", mask)
|
74 |
# embedding the indexed sequence to sequence of vectors
|
|
|
77 |
# running over multiple transformer blocks
|
78 |
for transformer in self.transformer_blocks:
|
79 |
x = transformer.forward(x, mask)
|
80 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
81 |
|
82 |
return x
|
src/classifier_model.py
CHANGED
@@ -1,16 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch.nn as nn
|
2 |
|
3 |
from bert import BERT
|
|
|
4 |
|
5 |
|
6 |
class BERTForClassification(nn.Module):
|
7 |
"""
|
|
|
|
|
|
|
8 |
Progress Classifier Model
|
|
|
9 |
"""
|
10 |
|
11 |
def __init__(self, bert: BERT, vocab_size, n_labels):
|
12 |
"""
|
13 |
:param bert: BERT model which should be trained
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
:param vocab_size: total vocab size for masked_lm
|
15 |
"""
|
16 |
|
@@ -21,4 +71,5 @@ class BERTForClassification(nn.Module):
|
|
21 |
|
22 |
def forward(self, x, segment_label):
|
23 |
x = self.bert(x, segment_label)
|
24 |
-
return x, self.linear(x[:, 0])
|
|
|
|
1 |
+
<<<<<<< HEAD
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .bert import BERT
|
6 |
+
=======
|
7 |
import torch.nn as nn
|
8 |
|
9 |
from bert import BERT
|
10 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
11 |
|
12 |
|
13 |
class BERTForClassification(nn.Module):
|
14 |
"""
|
15 |
+
<<<<<<< HEAD
|
16 |
+
Fine-tune Task Classifier Model
|
17 |
+
=======
|
18 |
Progress Classifier Model
|
19 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
20 |
"""
|
21 |
|
22 |
def __init__(self, bert: BERT, vocab_size, n_labels):
|
23 |
"""
|
24 |
:param bert: BERT model which should be trained
|
25 |
+
<<<<<<< HEAD
|
26 |
+
:param vocab_size: total vocab size
|
27 |
+
:param n_labels: number of labels for the task
|
28 |
+
"""
|
29 |
+
super().__init__()
|
30 |
+
self.bert = bert
|
31 |
+
self.linear = nn.Linear(self.bert.hidden, n_labels)
|
32 |
+
|
33 |
+
def forward(self, x, segment_label):
|
34 |
+
x = self.bert(x, segment_label)
|
35 |
+
return self.linear(x[:, 0])
|
36 |
+
|
37 |
+
class BERTForClassificationWithFeats(nn.Module):
|
38 |
+
"""
|
39 |
+
Fine-tune Task Classifier Model
|
40 |
+
BERT embeddings concatenated with features
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, bert: BERT, n_labels, feat_size=9):
|
44 |
+
"""
|
45 |
+
:param bert: BERT model which should be trained
|
46 |
+
:param vocab_size: total vocab size
|
47 |
+
:param n_labels: number of labels for the task
|
48 |
+
"""
|
49 |
+
super().__init__()
|
50 |
+
self.bert = bert
|
51 |
+
# self.linear1 = nn.Linear(self.bert.hidden+feat_size, 128)
|
52 |
+
self.linear = nn.Linear(self.bert.hidden+feat_size, n_labels)
|
53 |
+
# self.RELU = nn.ReLU()
|
54 |
+
# self.linear2 = nn.Linear(128, n_labels)
|
55 |
+
|
56 |
+
def forward(self, x, segment_label, feat):
|
57 |
+
x = self.bert(x, segment_label)
|
58 |
+
x = torch.cat((x[:, 0], feat), dim=-1)
|
59 |
+
# x = self.linear1(x)
|
60 |
+
# x = self.RELU(x)
|
61 |
+
# return self.linear2(x)
|
62 |
+
return self.linear(x)
|
63 |
+
=======
|
64 |
:param vocab_size: total vocab size for masked_lm
|
65 |
"""
|
66 |
|
|
|
71 |
|
72 |
def forward(self, x, segment_label):
|
73 |
x = self.bert(x, segment_label)
|
74 |
+
return x, self.linear(x[:, 0])
|
75 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
src/dataset.py
CHANGED
@@ -4,17 +4,28 @@ import pandas as pd
|
|
4 |
import numpy as np
|
5 |
import tqdm
|
6 |
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from vocab import Vocab
|
8 |
import pickle
|
9 |
import copy
|
10 |
from sklearn.preprocessing import OneHotEncoder
|
|
|
11 |
|
12 |
class PretrainerDataset(Dataset):
|
13 |
"""
|
14 |
Class name: PretrainDataset
|
15 |
|
16 |
"""
|
|
|
|
|
|
|
17 |
def __init__(self, dataset_path, vocab, seq_len=30, select_next_seq= False):
|
|
|
18 |
self.dataset_path = dataset_path
|
19 |
self.vocab = vocab # Vocab object
|
20 |
|
@@ -35,6 +46,22 @@ class PretrainerDataset(Dataset):
|
|
35 |
self.index_documents[i] = []
|
36 |
else:
|
37 |
self.index_documents[i].append(index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
self.lines.append(line.split())
|
39 |
len_line = len(line.split())
|
40 |
seq_len_list.append(len_line)
|
@@ -49,6 +76,7 @@ class PretrainerDataset(Dataset):
|
|
49 |
print("Sequence length set at ", self.seq_len)
|
50 |
print("select_next_seq: ", self.select_next_seq)
|
51 |
print(len(self.index_documents))
|
|
|
52 |
|
53 |
|
54 |
def __len__(self):
|
@@ -56,6 +84,53 @@ class PretrainerDataset(Dataset):
|
|
56 |
|
57 |
def __getitem__(self, item):
|
58 |
token_a = self.lines[item]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
token_b = None
|
60 |
is_same_student = None
|
61 |
sa_masked = None
|
@@ -92,6 +167,7 @@ class PretrainerDataset(Dataset):
|
|
92 |
if self.select_next_seq:
|
93 |
output['is_same_student'] = is_same_student
|
94 |
# print(item, len(s1), len(s1_label), len(segment_label))
|
|
|
95 |
return {key: torch.tensor(value) for key, value in output.items()}
|
96 |
|
97 |
def random_mask_seq(self, tokens):
|
@@ -100,6 +176,28 @@ class PretrainerDataset(Dataset):
|
|
100 |
Output: masked token seq, output label
|
101 |
"""
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
# masked_pos_label = {}
|
104 |
output_labels = []
|
105 |
output_tokens = copy.deepcopy(tokens)
|
@@ -108,17 +206,34 @@ class PretrainerDataset(Dataset):
|
|
108 |
for i, token in enumerate(tokens):
|
109 |
prob = random.random()
|
110 |
if prob < 0.15:
|
|
|
111 |
# chooses 15% of token positions at random
|
112 |
# prob /= 0.15
|
113 |
prob = random.random()
|
114 |
if prob < 0.8: #[MASK] token 80% of the time
|
115 |
output_tokens[i] = self.vocab.vocab['[MASK]']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
elif prob < 0.9: # a random token 10% of the time
|
117 |
# print(".......0.8-0.9......")
|
118 |
output_tokens[i] = random.randint(1, len(self.vocab.vocab)-1)
|
119 |
else: # the unchanged i-th token 10% of the time
|
120 |
# print(".......unchanged......")
|
121 |
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
|
|
122 |
# True Label
|
123 |
output_labels.append(self.vocab.vocab.get(token, self.vocab.vocab['[UNK]']))
|
124 |
# masked_pos_label[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
@@ -127,11 +242,53 @@ class PretrainerDataset(Dataset):
|
|
127 |
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
128 |
# Padded label
|
129 |
output_labels.append(self.vocab.vocab['[PAD]'])
|
|
|
|
|
|
|
|
|
130 |
# label_position = []
|
131 |
# label_tokens = []
|
132 |
# for k, v in masked_pos_label.items():
|
133 |
# label_position.append(k)
|
134 |
# label_tokens.append(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
return output_tokens, output_labels
|
136 |
|
137 |
def get_token_b(self, item):
|
@@ -167,6 +324,7 @@ class PretrainerDataset(Dataset):
|
|
167 |
else:
|
168 |
sb.pop()
|
169 |
return sa, sb
|
|
|
170 |
|
171 |
class TokenizerDataset(Dataset):
|
172 |
"""
|
@@ -174,15 +332,89 @@ class TokenizerDataset(Dataset):
|
|
174 |
Tokenize the data in the dataset
|
175 |
|
176 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
def __init__(self, dataset_path, label_path, vocab, seq_len=30, train=True):
|
178 |
self.dataset_path = dataset_path
|
179 |
self.label_path = label_path
|
180 |
self.vocab = vocab # Vocab object
|
181 |
self.encoder = OneHotEncoder(sparse_output=False)
|
|
|
182 |
|
183 |
# Related to input dataset file
|
184 |
self.lines = []
|
185 |
self.labels = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
self.labels = []
|
187 |
|
188 |
self.label_file = open(self.label_path, "r")
|
@@ -234,11 +466,14 @@ class TokenizerDataset(Dataset):
|
|
234 |
|
235 |
self.file = open(self.dataset_path, "r")
|
236 |
# index = 0
|
|
|
237 |
for line in self.file:
|
238 |
if line:
|
239 |
line = line.strip()
|
240 |
if line:
|
241 |
self.lines.append(line)
|
|
|
|
|
242 |
# if train:
|
243 |
# if index in indices_of_zeros:
|
244 |
# # if index in indices_of_prom:
|
@@ -253,17 +488,46 @@ class TokenizerDataset(Dataset):
|
|
253 |
# self.labels.append(labels[index])
|
254 |
# self.labels.append(progress[index])
|
255 |
# index += 1
|
|
|
256 |
self.file.close()
|
257 |
|
258 |
self.len = len(self.lines)
|
259 |
self.seq_len = seq_len
|
|
|
|
|
|
|
260 |
|
261 |
print("Sequence length set at ", self.seq_len, len(self.lines), len(self.labels))
|
|
|
262 |
|
263 |
def __len__(self):
|
264 |
return self.len
|
265 |
|
266 |
def __getitem__(self, item):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
|
268 |
s1 = self.vocab.to_seq(self.lines[item], self.seq_len) # This is like tokenizer and adds [CLS] and [SEP].
|
269 |
s1_label = self.labels[item]
|
@@ -274,11 +538,132 @@ class TokenizerDataset(Dataset):
|
|
274 |
|
275 |
output = {'bert_input': s1,
|
276 |
'progress_status': s1_label,
|
|
|
277 |
'segment_label': segment_label}
|
278 |
return {key: torch.tensor(value) for key, value in output.items()}
|
279 |
|
280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
# if __name__ == "__main__":
|
|
|
282 |
# # import pickle
|
283 |
# # k = pickle.load(open("dataset/CL4999_1920/unique_steps_list.pkl","rb"))
|
284 |
# # print(k)
|
|
|
4 |
import numpy as np
|
5 |
import tqdm
|
6 |
import random
|
7 |
+
<<<<<<< HEAD
|
8 |
+
from .vocab import Vocab
|
9 |
+
import pickle
|
10 |
+
import copy
|
11 |
+
# from sklearn.preprocessing import OneHotEncoder
|
12 |
+
=======
|
13 |
from vocab import Vocab
|
14 |
import pickle
|
15 |
import copy
|
16 |
from sklearn.preprocessing import OneHotEncoder
|
17 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
18 |
|
19 |
class PretrainerDataset(Dataset):
|
20 |
"""
|
21 |
Class name: PretrainDataset
|
22 |
|
23 |
"""
|
24 |
+
<<<<<<< HEAD
|
25 |
+
def __init__(self, dataset_path, vocab, seq_len=30, max_mask=0.15):
|
26 |
+
=======
|
27 |
def __init__(self, dataset_path, vocab, seq_len=30, select_next_seq= False):
|
28 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
29 |
self.dataset_path = dataset_path
|
30 |
self.vocab = vocab # Vocab object
|
31 |
|
|
|
46 |
self.index_documents[i] = []
|
47 |
else:
|
48 |
self.index_documents[i].append(index)
|
49 |
+
<<<<<<< HEAD
|
50 |
+
self.lines.append(line.split("\t"))
|
51 |
+
len_line = len(line.split("\t"))
|
52 |
+
seq_len_list.append(len_line)
|
53 |
+
index+=1
|
54 |
+
reader.close()
|
55 |
+
print("Sequence Stats: len: %s, min: %s, max: %s, average: %s"% (len(seq_len_list),
|
56 |
+
min(seq_len_list), max(seq_len_list), sum(seq_len_list)/len(seq_len_list)))
|
57 |
+
print("Unique Sequences: ", len({tuple(ll) for ll in self.lines}))
|
58 |
+
self.index_documents = {k:v for k,v in self.index_documents.items() if v}
|
59 |
+
print(len(self.index_documents))
|
60 |
+
self.seq_len = seq_len
|
61 |
+
print("Sequence length set at: ", self.seq_len)
|
62 |
+
self.max_mask = max_mask
|
63 |
+
print("% of input tokens selected for masking : ",self.max_mask)
|
64 |
+
=======
|
65 |
self.lines.append(line.split())
|
66 |
len_line = len(line.split())
|
67 |
seq_len_list.append(len_line)
|
|
|
76 |
print("Sequence length set at ", self.seq_len)
|
77 |
print("select_next_seq: ", self.select_next_seq)
|
78 |
print(len(self.index_documents))
|
79 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
80 |
|
81 |
|
82 |
def __len__(self):
|
|
|
84 |
|
85 |
def __getitem__(self, item):
|
86 |
token_a = self.lines[item]
|
87 |
+
<<<<<<< HEAD
|
88 |
+
# sa_masked = None
|
89 |
+
# sa_masked_label = None
|
90 |
+
# token_b = None
|
91 |
+
# is_same_student = None
|
92 |
+
# sb_masked = None
|
93 |
+
# sb_masked_label = None
|
94 |
+
|
95 |
+
# if self.select_next_seq:
|
96 |
+
# is_same_student, token_b = self.get_token_b(item)
|
97 |
+
# is_same_student = 1 if is_same_student else 0
|
98 |
+
# token_a1, token_b1 = self.truncate_to_max_seq(token_a, token_b)
|
99 |
+
# sa_masked, sa_masked_label = self.random_mask_seq(token_a1)
|
100 |
+
# sb_masked, sb_masked_label = self.random_mask_seq(token_b1)
|
101 |
+
# else:
|
102 |
+
token_a = token_a[:self.seq_len-2]
|
103 |
+
sa_masked, sa_masked_label, sa_masked_pos = self.random_mask_seq(token_a)
|
104 |
+
|
105 |
+
s1 = ([self.vocab.vocab['[CLS]']] + sa_masked + [self.vocab.vocab['[SEP]']])
|
106 |
+
s1_label = ([self.vocab.vocab['[PAD]']] + sa_masked_label + [self.vocab.vocab['[PAD]']])
|
107 |
+
segment_label = [1 for _ in range(len(s1))]
|
108 |
+
masked_pos = ([0] + sa_masked_pos + [0])
|
109 |
+
|
110 |
+
# if self.select_next_seq:
|
111 |
+
# s1 = s1 + sb_masked + [self.vocab.vocab['[SEP]']]
|
112 |
+
# s1_label = s1_label + sb_masked_label + [self.vocab.vocab['[PAD]']]
|
113 |
+
# segment_label = segment_label + [2 for _ in range(len(sb_masked)+1)]
|
114 |
+
|
115 |
+
padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
|
116 |
+
s1.extend(padding)
|
117 |
+
s1_label.extend(padding)
|
118 |
+
segment_label.extend(padding)
|
119 |
+
masked_pos.extend(padding)
|
120 |
+
|
121 |
+
output = {'bert_input': s1,
|
122 |
+
'bert_label': s1_label,
|
123 |
+
'segment_label': segment_label,
|
124 |
+
'masked_pos': masked_pos}
|
125 |
+
# print(f"tokenA: {token_a}")
|
126 |
+
# print(f"output: {output}")
|
127 |
+
|
128 |
+
# if self.select_next_seq:
|
129 |
+
# output['is_same_student'] = is_same_student
|
130 |
+
|
131 |
+
# print(item, len(s1), len(s1_label), len(segment_label))
|
132 |
+
# print(f"{item}.")
|
133 |
+
=======
|
134 |
token_b = None
|
135 |
is_same_student = None
|
136 |
sa_masked = None
|
|
|
167 |
if self.select_next_seq:
|
168 |
output['is_same_student'] = is_same_student
|
169 |
# print(item, len(s1), len(s1_label), len(segment_label))
|
170 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
171 |
return {key: torch.tensor(value) for key, value in output.items()}
|
172 |
|
173 |
def random_mask_seq(self, tokens):
|
|
|
176 |
Output: masked token seq, output label
|
177 |
"""
|
178 |
|
179 |
+
<<<<<<< HEAD
|
180 |
+
masked_pos = []
|
181 |
+
output_labels = []
|
182 |
+
output_tokens = copy.deepcopy(tokens)
|
183 |
+
opt_step = False
|
184 |
+
for i, token in enumerate(tokens):
|
185 |
+
if token in ['OptionalTask_1', 'EquationAnswer', 'NumeratorFactor', 'DenominatorFactor', 'OptionalTask_2', 'FirstRow1:1', 'FirstRow1:2', 'FirstRow2:1', 'FirstRow2:2', 'SecondRow', 'ThirdRow']:
|
186 |
+
opt_step = True
|
187 |
+
# if opt_step:
|
188 |
+
# prob = random.random()
|
189 |
+
# if prob < self.max_mask:
|
190 |
+
# output_tokens[i] = random.choice([3,7,8,9,11,12,13,14,15,16,22,23,24,25,26,27,30,31,32])
|
191 |
+
# masked_pos.append(1)
|
192 |
+
# else:
|
193 |
+
# output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
194 |
+
# masked_pos.append(0)
|
195 |
+
# output_labels.append(self.vocab.vocab.get(token, self.vocab.vocab['[UNK]']))
|
196 |
+
# opt_step = False
|
197 |
+
# else:
|
198 |
+
prob = random.random()
|
199 |
+
if prob < self.max_mask:
|
200 |
+
=======
|
201 |
# masked_pos_label = {}
|
202 |
output_labels = []
|
203 |
output_tokens = copy.deepcopy(tokens)
|
|
|
206 |
for i, token in enumerate(tokens):
|
207 |
prob = random.random()
|
208 |
if prob < 0.15:
|
209 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
210 |
# chooses 15% of token positions at random
|
211 |
# prob /= 0.15
|
212 |
prob = random.random()
|
213 |
if prob < 0.8: #[MASK] token 80% of the time
|
214 |
output_tokens[i] = self.vocab.vocab['[MASK]']
|
215 |
+
<<<<<<< HEAD
|
216 |
+
masked_pos.append(1)
|
217 |
+
elif prob < 0.9: # a random token 10% of the time
|
218 |
+
# print(".......0.8-0.9......")
|
219 |
+
if opt_step:
|
220 |
+
output_tokens[i] = random.choice([7,8,9,11,12,13,14,15,16,22,23,24,25,26,27,30,31,32])
|
221 |
+
opt_step = False
|
222 |
+
else:
|
223 |
+
output_tokens[i] = random.randint(1, len(self.vocab.vocab)-1)
|
224 |
+
masked_pos.append(1)
|
225 |
+
else: # the unchanged i-th token 10% of the time
|
226 |
+
# print(".......unchanged......")
|
227 |
+
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
228 |
+
masked_pos.append(0)
|
229 |
+
=======
|
230 |
elif prob < 0.9: # a random token 10% of the time
|
231 |
# print(".......0.8-0.9......")
|
232 |
output_tokens[i] = random.randint(1, len(self.vocab.vocab)-1)
|
233 |
else: # the unchanged i-th token 10% of the time
|
234 |
# print(".......unchanged......")
|
235 |
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
236 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
237 |
# True Label
|
238 |
output_labels.append(self.vocab.vocab.get(token, self.vocab.vocab['[UNK]']))
|
239 |
# masked_pos_label[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
|
|
242 |
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
243 |
# Padded label
|
244 |
output_labels.append(self.vocab.vocab['[PAD]'])
|
245 |
+
<<<<<<< HEAD
|
246 |
+
masked_pos.append(0)
|
247 |
+
=======
|
248 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
249 |
# label_position = []
|
250 |
# label_tokens = []
|
251 |
# for k, v in masked_pos_label.items():
|
252 |
# label_position.append(k)
|
253 |
# label_tokens.append(v)
|
254 |
+
<<<<<<< HEAD
|
255 |
+
return output_tokens, output_labels, masked_pos
|
256 |
+
|
257 |
+
# def get_token_b(self, item):
|
258 |
+
# document_id = [k for k,v in self.index_documents.items() if item in v][0]
|
259 |
+
# random_document_id = document_id
|
260 |
+
|
261 |
+
# if random.random() < 0.5:
|
262 |
+
# document_ids = [k for k in self.index_documents.keys() if k != document_id]
|
263 |
+
# random_document_id = random.choice(document_ids)
|
264 |
+
|
265 |
+
# same_student = (random_document_id == document_id)
|
266 |
+
|
267 |
+
# nex_seq_list = self.index_documents.get(random_document_id)
|
268 |
+
|
269 |
+
# if same_student:
|
270 |
+
# if len(nex_seq_list) != 1:
|
271 |
+
# nex_seq_list = [v for v in nex_seq_list if v !=item]
|
272 |
+
|
273 |
+
# next_seq = random.choice(nex_seq_list)
|
274 |
+
# tokens = self.lines[next_seq]
|
275 |
+
# # print(f"item = {item}, tokens: {tokens}")
|
276 |
+
# # print(f"item={item}, next={next_seq}, same_student = {same_student}, {document_id} == {random_document_id}, b. {tokens}")
|
277 |
+
# return same_student, tokens
|
278 |
+
|
279 |
+
# def truncate_to_max_seq(self, s1, s2):
|
280 |
+
# sa = copy.deepcopy(s1)
|
281 |
+
# sb = copy.deepcopy(s1)
|
282 |
+
# total_allowed_seq = self.seq_len - 3
|
283 |
+
|
284 |
+
# while((len(sa)+len(sb)) > total_allowed_seq):
|
285 |
+
# if random.random() < 0.5:
|
286 |
+
# sa.pop()
|
287 |
+
# else:
|
288 |
+
# sb.pop()
|
289 |
+
# return sa, sb
|
290 |
+
|
291 |
+
=======
|
292 |
return output_tokens, output_labels
|
293 |
|
294 |
def get_token_b(self, item):
|
|
|
324 |
else:
|
325 |
sb.pop()
|
326 |
return sa, sb
|
327 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
328 |
|
329 |
class TokenizerDataset(Dataset):
|
330 |
"""
|
|
|
332 |
Tokenize the data in the dataset
|
333 |
|
334 |
"""
|
335 |
+
<<<<<<< HEAD
|
336 |
+
def __init__(self, dataset_path, label_path, vocab, seq_len=30):
|
337 |
+
self.dataset_path = dataset_path
|
338 |
+
self.label_path = label_path
|
339 |
+
self.vocab = vocab # Vocab object
|
340 |
+
# self.encoder = OneHotEncoder(sparse=False)
|
341 |
+
=======
|
342 |
def __init__(self, dataset_path, label_path, vocab, seq_len=30, train=True):
|
343 |
self.dataset_path = dataset_path
|
344 |
self.label_path = label_path
|
345 |
self.vocab = vocab # Vocab object
|
346 |
self.encoder = OneHotEncoder(sparse_output=False)
|
347 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
348 |
|
349 |
# Related to input dataset file
|
350 |
self.lines = []
|
351 |
self.labels = []
|
352 |
+
<<<<<<< HEAD
|
353 |
+
self.feats = []
|
354 |
+
if self.label_path:
|
355 |
+
self.label_file = open(self.label_path, "r")
|
356 |
+
for line in self.label_file:
|
357 |
+
if line:
|
358 |
+
line = line.strip()
|
359 |
+
if not line:
|
360 |
+
continue
|
361 |
+
self.labels.append(int(line))
|
362 |
+
self.label_file.close()
|
363 |
+
|
364 |
+
# Comment this section if you are not using feat attribute
|
365 |
+
try:
|
366 |
+
j = 0
|
367 |
+
dataset_info_file = open(self.label_path.replace("label", "info"), "r")
|
368 |
+
for line in dataset_info_file:
|
369 |
+
if line:
|
370 |
+
line = line.strip()
|
371 |
+
if not line:
|
372 |
+
continue
|
373 |
+
|
374 |
+
# # highGRschool_w_prior
|
375 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
376 |
+
|
377 |
+
# highGRschool_w_prior_w_diffskill_wo_fa
|
378 |
+
feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
379 |
+
feat2 = [float(i) for i in line.split(",")[-2].split("\t")]
|
380 |
+
feat_vec.extend(feat2[1:])
|
381 |
+
|
382 |
+
# # highGRschool_w_prior_w_p_diffskill_wo_fa
|
383 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
384 |
+
# feat2 = [-float(i) for i in line.split(",")[-2].split("\t")]
|
385 |
+
# feat_vec.extend(feat2[1:])
|
386 |
+
|
387 |
+
# # highGRschool_w_prior_w_diffskill_0fa_skill
|
388 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
389 |
+
# feat2 = [float(i) for i in line.split(",")[-2].split("\t")]
|
390 |
+
# fa_feat_vec = [float(i) for i in line.split(",")[-1].split("\t")]
|
391 |
+
|
392 |
+
# diff_skill = [f2 if f1==0 else 0 for f2, f1 in zip(feat2, fa_feat_vec)]
|
393 |
+
# feat_vec.extend(diff_skill)
|
394 |
+
|
395 |
+
if j == 0:
|
396 |
+
print(len(feat_vec))
|
397 |
+
j+=1
|
398 |
+
|
399 |
+
# feat_vec.extend(feat2[1:])
|
400 |
+
# feat_vec.extend(feat2)
|
401 |
+
# feat_vec = [float(i) for i in line.split(",")[-2].split("\t")]
|
402 |
+
# feat_vec = feat_vec[1:]
|
403 |
+
# feat_vec = [float(line.split(",")[-1])]
|
404 |
+
# feat_vec = [float(i) for i in line.split(",")[-1].split("\t")]
|
405 |
+
# feat_vec = [ft-f1 for ft, f1 in zip(feat_vec, fa_feat_vec)]
|
406 |
+
|
407 |
+
self.feats.append(feat_vec)
|
408 |
+
dataset_info_file.close()
|
409 |
+
except Exception as e:
|
410 |
+
print(e)
|
411 |
+
# labeler = np.array([0, 1]) #np.unique(self.labels)
|
412 |
+
# print(f"Labeler {labeler}")
|
413 |
+
# self.encoder.fit(labeler.reshape(-1,1))
|
414 |
+
# self.labels = self.encoder.transform(np.array(self.labels).reshape(-1,1))
|
415 |
+
|
416 |
+
self.file = open(self.dataset_path, "r")
|
417 |
+
=======
|
418 |
self.labels = []
|
419 |
|
420 |
self.label_file = open(self.label_path, "r")
|
|
|
466 |
|
467 |
self.file = open(self.dataset_path, "r")
|
468 |
# index = 0
|
469 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
470 |
for line in self.file:
|
471 |
if line:
|
472 |
line = line.strip()
|
473 |
if line:
|
474 |
self.lines.append(line)
|
475 |
+
<<<<<<< HEAD
|
476 |
+
=======
|
477 |
# if train:
|
478 |
# if index in indices_of_zeros:
|
479 |
# # if index in indices_of_prom:
|
|
|
488 |
# self.labels.append(labels[index])
|
489 |
# self.labels.append(progress[index])
|
490 |
# index += 1
|
491 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
492 |
self.file.close()
|
493 |
|
494 |
self.len = len(self.lines)
|
495 |
self.seq_len = seq_len
|
496 |
+
<<<<<<< HEAD
|
497 |
+
print("Sequence length set at ", self.seq_len, len(self.lines), len(self.labels) if self.label_path else 0)
|
498 |
+
=======
|
499 |
|
500 |
print("Sequence length set at ", self.seq_len, len(self.lines), len(self.labels))
|
501 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
502 |
|
503 |
def __len__(self):
|
504 |
return self.len
|
505 |
|
506 |
def __getitem__(self, item):
|
507 |
+
<<<<<<< HEAD
|
508 |
+
org_line = self.lines[item].split("\t")
|
509 |
+
dup_line = []
|
510 |
+
opt = False
|
511 |
+
for l in org_line:
|
512 |
+
if l in ["OptionalTask_1", "EquationAnswer", "NumeratorFactor", "DenominatorFactor", "OptionalTask_2", "FirstRow1:1", "FirstRow1:2", "FirstRow2:1", "FirstRow2:2", "SecondRow", "ThirdRow"]:
|
513 |
+
opt = True
|
514 |
+
if opt and 'FinalAnswer-' in l:
|
515 |
+
dup_line.append('[UNK]')
|
516 |
+
else:
|
517 |
+
dup_line.append(l)
|
518 |
+
dup_line = "\t".join(dup_line)
|
519 |
+
# print(dup_line)
|
520 |
+
s1 = self.vocab.to_seq(dup_line, self.seq_len) # This is like tokenizer and adds [CLS] and [SEP].
|
521 |
+
s1_label = self.labels[item] if self.label_path else 0
|
522 |
+
segment_label = [1 for _ in range(len(s1))]
|
523 |
+
s1_feat = self.feats[item] if len(self.feats)>0 else 0
|
524 |
+
padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
|
525 |
+
s1.extend(padding), segment_label.extend(padding)
|
526 |
+
|
527 |
+
output = {'input': s1,
|
528 |
+
'label': s1_label,
|
529 |
+
'feat': s1_feat,
|
530 |
+
=======
|
531 |
|
532 |
s1 = self.vocab.to_seq(self.lines[item], self.seq_len) # This is like tokenizer and adds [CLS] and [SEP].
|
533 |
s1_label = self.labels[item]
|
|
|
538 |
|
539 |
output = {'bert_input': s1,
|
540 |
'progress_status': s1_label,
|
541 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
542 |
'segment_label': segment_label}
|
543 |
return {key: torch.tensor(value) for key, value in output.items()}
|
544 |
|
545 |
|
546 |
+
<<<<<<< HEAD
|
547 |
+
class TokenizerDatasetForCalibration(Dataset):
|
548 |
+
"""
|
549 |
+
Class name: TokenizerDataset
|
550 |
+
Tokenize the data in the dataset
|
551 |
+
|
552 |
+
"""
|
553 |
+
def __init__(self, dataset_path, label_path, vocab, seq_len=30):
|
554 |
+
self.dataset_path = dataset_path
|
555 |
+
self.label_path = label_path
|
556 |
+
self.vocab = vocab # Vocab object
|
557 |
+
# self.encoder = OneHotEncoder(sparse=False)
|
558 |
+
|
559 |
+
# Related to input dataset file
|
560 |
+
self.lines = []
|
561 |
+
self.labels = []
|
562 |
+
self.feats = []
|
563 |
+
if self.label_path:
|
564 |
+
self.label_file = open(self.label_path, "r")
|
565 |
+
for line in self.label_file:
|
566 |
+
if line:
|
567 |
+
line = line.strip()
|
568 |
+
if not line:
|
569 |
+
continue
|
570 |
+
self.labels.append(int(line))
|
571 |
+
self.label_file.close()
|
572 |
+
|
573 |
+
# Comment this section if you are not using feat attribute
|
574 |
+
try:
|
575 |
+
j = 0
|
576 |
+
dataset_info_file = open(self.label_path.replace("label", "info"), "r")
|
577 |
+
for line in dataset_info_file:
|
578 |
+
if line:
|
579 |
+
line = line.strip()
|
580 |
+
if not line:
|
581 |
+
continue
|
582 |
+
|
583 |
+
# # highGRschool_w_prior
|
584 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
585 |
+
|
586 |
+
# highGRschool_w_prior_w_diffskill_wo_fa
|
587 |
+
feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
588 |
+
feat2 = [float(i) for i in line.split(",")[-2].split("\t")]
|
589 |
+
feat_vec.extend(feat2[1:])
|
590 |
+
|
591 |
+
# # highGRschool_w_prior_w_diffskill_0fa_skill
|
592 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
593 |
+
# feat2 = [float(i) for i in line.split(",")[-2].split("\t")]
|
594 |
+
# fa_feat_vec = [float(i) for i in line.split(",")[-1].split("\t")]
|
595 |
+
|
596 |
+
# diff_skill = [f2 if f1==0 else 0 for f2, f1 in zip(feat2, fa_feat_vec)]
|
597 |
+
# feat_vec.extend(diff_skill)
|
598 |
+
|
599 |
+
if j == 0:
|
600 |
+
print(len(feat_vec))
|
601 |
+
j+=1
|
602 |
+
|
603 |
+
# feat_vec.extend(feat2[1:])
|
604 |
+
# feat_vec.extend(feat2)
|
605 |
+
# feat_vec = [float(i) for i in line.split(",")[-2].split("\t")]
|
606 |
+
# feat_vec = feat_vec[1:]
|
607 |
+
# feat_vec = [float(line.split(",")[-1])]
|
608 |
+
# feat_vec = [float(i) for i in line.split(",")[-1].split("\t")]
|
609 |
+
# feat_vec = [ft-f1 for ft, f1 in zip(feat_vec, fa_feat_vec)]
|
610 |
+
|
611 |
+
self.feats.append(feat_vec)
|
612 |
+
dataset_info_file.close()
|
613 |
+
except Exception as e:
|
614 |
+
print(e)
|
615 |
+
# labeler = np.array([0, 1]) #np.unique(self.labels)
|
616 |
+
# print(f"Labeler {labeler}")
|
617 |
+
# self.encoder.fit(labeler.reshape(-1,1))
|
618 |
+
# self.labels = self.encoder.transform(np.array(self.labels).reshape(-1,1))
|
619 |
+
|
620 |
+
self.file = open(self.dataset_path, "r")
|
621 |
+
for line in self.file:
|
622 |
+
if line:
|
623 |
+
line = line.strip()
|
624 |
+
if line:
|
625 |
+
self.lines.append(line)
|
626 |
+
self.file.close()
|
627 |
+
|
628 |
+
self.len = len(self.lines)
|
629 |
+
self.seq_len = seq_len
|
630 |
+
print("Sequence length set at ", self.seq_len, len(self.lines), len(self.labels) if self.label_path else 0)
|
631 |
+
|
632 |
+
def __len__(self):
|
633 |
+
return self.len
|
634 |
+
|
635 |
+
def __getitem__(self, item):
|
636 |
+
org_line = self.lines[item].split("\t")
|
637 |
+
dup_line = []
|
638 |
+
opt = False
|
639 |
+
for l in org_line:
|
640 |
+
if l in ["OptionalTask_1", "EquationAnswer", "NumeratorFactor", "DenominatorFactor", "OptionalTask_2", "FirstRow1:1", "FirstRow1:2", "FirstRow2:1", "FirstRow2:2", "SecondRow", "ThirdRow"]:
|
641 |
+
opt = True
|
642 |
+
if opt and 'FinalAnswer-' in l:
|
643 |
+
dup_line.append('[UNK]')
|
644 |
+
else:
|
645 |
+
dup_line.append(l)
|
646 |
+
dup_line = "\t".join(dup_line)
|
647 |
+
# print(dup_line)
|
648 |
+
s1 = self.vocab.to_seq(dup_line, self.seq_len) # This is like tokenizer and adds [CLS] and [SEP].
|
649 |
+
s1_label = self.labels[item] if self.label_path else 0
|
650 |
+
segment_label = [1 for _ in range(len(s1))]
|
651 |
+
s1_feat = self.feats[item] if len(self.feats)>0 else 0
|
652 |
+
padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
|
653 |
+
s1.extend(padding), segment_label.extend(padding)
|
654 |
+
|
655 |
+
output = {'input': s1,
|
656 |
+
'label': s1_label,
|
657 |
+
'feat': s1_feat,
|
658 |
+
'segment_label': segment_label}
|
659 |
+
return ({key: torch.tensor(value) for key, value in output.items()}, s1_label)
|
660 |
+
|
661 |
+
|
662 |
+
|
663 |
+
# if __name__ == "__main__":
|
664 |
+
=======
|
665 |
# if __name__ == "__main__":
|
666 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
667 |
# # import pickle
|
668 |
# # k = pickle.load(open("dataset/CL4999_1920/unique_steps_list.pkl","rb"))
|
669 |
# # print(k)
|
src/pretrainer.py
CHANGED
@@ -1,5 +1,42 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from torch.nn import functional as F
|
4 |
from torch.optim import Adam, SGD
|
5 |
from torch.utils.data import DataLoader
|
@@ -67,6 +104,7 @@ class BERTTrainer:
|
|
67 |
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
68 |
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, same_student_prediction = False,
|
69 |
workspace_name=None):
|
|
|
70 |
"""
|
71 |
:param bert: BERT model which you want to train
|
72 |
:param vocab_size: total word vocab size
|
@@ -79,6 +117,17 @@ class BERTTrainer:
|
|
79 |
:param log_freq: logging frequency of the batch iteration
|
80 |
"""
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
83 |
cuda_condition = torch.cuda.is_available() and with_cuda
|
84 |
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
@@ -87,15 +136,24 @@ class BERTTrainer:
|
|
87 |
# This BERT model will be saved every epoch
|
88 |
self.bert = bert
|
89 |
# Initialize the BERT Language Model, with BERT model
|
|
|
90 |
self.model = BERTSM(bert, vocab_size).to(self.device)
|
91 |
|
92 |
# Distributed GPU training if CUDA can detect more than 1 GPU
|
93 |
if with_cuda and torch.cuda.device_count() > 1:
|
94 |
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
|
96 |
|
97 |
# Setting the train and test data loader
|
98 |
self.train_data = train_dataloader
|
|
|
99 |
self.test_data = test_dataloader
|
100 |
|
101 |
# Setting the Adam optimizer with hyper-param
|
@@ -106,19 +164,44 @@ class BERTTrainer:
|
|
106 |
self.criterion = nn.NLLLoss(ignore_index=0)
|
107 |
|
108 |
self.log_freq = log_freq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
self.same_student_prediction = same_student_prediction
|
110 |
self.workspace_name = workspace_name
|
111 |
self.save_model = False
|
112 |
self.avg_loss = 10000
|
|
|
113 |
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
114 |
|
115 |
def train(self, epoch):
|
116 |
self.iteration(epoch, self.train_data)
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
def test(self, epoch):
|
119 |
self.iteration(epoch, self.test_data, train=False)
|
120 |
|
121 |
def iteration(self, epoch, data_loader, train=True):
|
|
|
122 |
"""
|
123 |
loop over the data_loader for training or testing
|
124 |
if on train status, backward operation is activated
|
@@ -129,6 +212,30 @@ class BERTTrainer:
|
|
129 |
:param train: boolean value of is train or test
|
130 |
:return: None
|
131 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
str_code = "train" if train else "test"
|
133 |
code = "masked_prediction" if self.same_student_prediction else "masked"
|
134 |
|
@@ -155,10 +262,25 @@ class BERTTrainer:
|
|
155 |
|
156 |
avg_loss = 0.0
|
157 |
with open(self.log_file, 'a') as f:
|
|
|
158 |
sys.stdout = f
|
159 |
for i, data in data_iter:
|
160 |
# 0. batch_data will be sent into the device(GPU or cpu)
|
161 |
data = {key: value.to(self.device) for key, value in data.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
# 1. forward the next_sentence_prediction and masked_lm model
|
164 |
# next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])
|
@@ -184,10 +306,49 @@ class BERTTrainer:
|
|
184 |
|
185 |
# 3. backward and optimization only in train
|
186 |
if train:
|
|
|
187 |
self.optim_schedule.zero_grad()
|
188 |
loss.backward()
|
189 |
self.optim_schedule.step_and_update_lr()
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
non_zero_mask = (data["bert_label"] != 0).float()
|
193 |
predictions = torch.argmax(mask_lm_output, dim=-1)
|
@@ -249,6 +410,7 @@ class BERTTrainer:
|
|
249 |
# pickle.dump(bert_hidden_representations, open(f"embeddings/{code}/{str_code}_embeddings_{epoch}.pkl","wb"))
|
250 |
|
251 |
|
|
|
252 |
|
253 |
def save(self, epoch, file_path="output/bert_trained.model"):
|
254 |
"""
|
@@ -270,7 +432,12 @@ class BERTFineTuneTrainer:
|
|
270 |
def __init__(self, bert: BERT, vocab_size: int,
|
271 |
train_dataloader: DataLoader, test_dataloader: DataLoader = None,
|
272 |
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
|
|
|
|
|
|
|
|
273 |
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None, num_labels=2):
|
|
|
274 |
"""
|
275 |
:param bert: BERT model which you want to train
|
276 |
:param vocab_size: total word vocab size
|
@@ -286,6 +453,302 @@ class BERTFineTuneTrainer:
|
|
286 |
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
287 |
cuda_condition = torch.cuda.is_available() and with_cuda
|
288 |
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
289 |
print("Device used = ", self.device)
|
290 |
|
291 |
# This BERT model will be saved every epoch
|
@@ -320,15 +783,28 @@ class BERTFineTuneTrainer:
|
|
320 |
self.workspace_name = workspace_name
|
321 |
self.save_model = False
|
322 |
self.avg_loss = 10000
|
|
|
323 |
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
324 |
|
325 |
def train(self, epoch):
|
326 |
self.iteration(epoch, self.train_data)
|
327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
def test(self, epoch):
|
329 |
self.iteration(epoch, self.test_data, train=False)
|
330 |
|
331 |
def iteration(self, epoch, data_loader, train=True):
|
|
|
332 |
"""
|
333 |
loop over the data_loader for training or testing
|
334 |
if on train status, backward operation is activated
|
@@ -339,6 +815,12 @@ class BERTFineTuneTrainer:
|
|
339 |
:param train: boolean value of is train or test
|
340 |
:return: None
|
341 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
str_code = "train" if train else "test"
|
343 |
|
344 |
self.log_file = f"{self.workspace_name}/logs/masked/log_{str_code}_FS_finetuned.txt"
|
@@ -352,6 +834,7 @@ class BERTFineTuneTrainer:
|
|
352 |
# Setting the tqdm progress bar
|
353 |
data_iter = tqdm.tqdm(enumerate(data_loader),
|
354 |
desc="EP_%s:%d" % (str_code, epoch),
|
|
|
355 |
total=len(data_loader),
|
356 |
bar_format="{l_bar}{r_bar}")
|
357 |
|
@@ -360,6 +843,28 @@ class BERTFineTuneTrainer:
|
|
360 |
total_element = 0
|
361 |
plabels = []
|
362 |
tlabels = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
eval_accurate_nb = 0
|
364 |
nb_eval_examples = 0
|
365 |
logits_list = []
|
@@ -390,10 +895,81 @@ class BERTFineTuneTrainer:
|
|
390 |
progress_loss = self.criterion(logits, data["progress_status"])
|
391 |
loss = progress_loss
|
392 |
|
|
|
393 |
if torch.cuda.device_count() > 1:
|
394 |
loss = loss.mean()
|
395 |
|
396 |
# 3. backward and optimization only in train
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
if train:
|
398 |
self.optim.zero_grad()
|
399 |
loss.backward()
|
@@ -489,13 +1065,40 @@ class BERTFineTuneTrainer:
|
|
489 |
f.close()
|
490 |
sys.stdout = sys.__stdout__
|
491 |
if train:
|
|
|
492 |
self.save_model = False
|
493 |
if self.avg_loss > (avg_loss / len(data_iter)):
|
494 |
self.save_model = True
|
495 |
self.avg_loss = (avg_loss / len(data_iter))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
|
497 |
# plt_test.show()
|
498 |
# print("EP%d_%s, " % (epoch, str_code))
|
|
|
499 |
|
500 |
def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
|
501 |
"""
|
@@ -510,3 +1113,113 @@ class BERTFineTuneTrainer:
|
|
510 |
self.model.to(self.device)
|
511 |
print("EP:%d Model Saved on:" % epoch, output_path)
|
512 |
return output_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
+
<<<<<<< HEAD
|
4 |
+
# from torch.nn import functional as F
|
5 |
+
from torch.optim import Adam
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
# import pickle
|
8 |
+
|
9 |
+
from .bert import BERT
|
10 |
+
from .seq_model import BERTSM
|
11 |
+
from .classifier_model import BERTForClassification, BERTForClassificationWithFeats
|
12 |
+
from .optim_schedule import ScheduledOptim
|
13 |
+
|
14 |
+
import tqdm
|
15 |
+
import sys
|
16 |
+
import time
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
|
21 |
+
|
22 |
+
import matplotlib.pyplot as plt
|
23 |
+
import seaborn as sns
|
24 |
+
import pandas as pd
|
25 |
+
from collections import defaultdict
|
26 |
+
import os
|
27 |
+
|
28 |
+
class BERTTrainer:
|
29 |
+
"""
|
30 |
+
BERTTrainer pretrains BERT model on input sequence of strategies.
|
31 |
+
BERTTrainer make the pretrained BERT model with one training method objective.
|
32 |
+
1. Masked Strategy Modeling :Masked SM
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, bert: BERT, vocab_size: int,
|
36 |
+
train_dataloader: DataLoader, val_dataloader: DataLoader = None, test_dataloader: DataLoader = None,
|
37 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=5000,
|
38 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, log_folder_path: str = None):
|
39 |
+
=======
|
40 |
from torch.nn import functional as F
|
41 |
from torch.optim import Adam, SGD
|
42 |
from torch.utils.data import DataLoader
|
|
|
104 |
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
105 |
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, same_student_prediction = False,
|
106 |
workspace_name=None):
|
107 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
108 |
"""
|
109 |
:param bert: BERT model which you want to train
|
110 |
:param vocab_size: total word vocab size
|
|
|
117 |
:param log_freq: logging frequency of the batch iteration
|
118 |
"""
|
119 |
|
120 |
+
<<<<<<< HEAD
|
121 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
122 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
123 |
+
print(cuda_condition, " Device used = ", self.device)
|
124 |
+
|
125 |
+
available_gpus = list(range(torch.cuda.device_count()))
|
126 |
+
|
127 |
+
# This BERT model will be saved
|
128 |
+
self.bert = bert.to(self.device)
|
129 |
+
# Initialize the BERT Sequence Model, with BERT model
|
130 |
+
=======
|
131 |
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
132 |
cuda_condition = torch.cuda.is_available() and with_cuda
|
133 |
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
|
|
136 |
# This BERT model will be saved every epoch
|
137 |
self.bert = bert
|
138 |
# Initialize the BERT Language Model, with BERT model
|
139 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
140 |
self.model = BERTSM(bert, vocab_size).to(self.device)
|
141 |
|
142 |
# Distributed GPU training if CUDA can detect more than 1 GPU
|
143 |
if with_cuda and torch.cuda.device_count() > 1:
|
144 |
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
145 |
+
<<<<<<< HEAD
|
146 |
+
self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
147 |
+
|
148 |
+
# Setting the train, validation and test data loader
|
149 |
+
self.train_data = train_dataloader
|
150 |
+
self.val_data = val_dataloader
|
151 |
+
=======
|
152 |
self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
|
153 |
|
154 |
# Setting the train and test data loader
|
155 |
self.train_data = train_dataloader
|
156 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
157 |
self.test_data = test_dataloader
|
158 |
|
159 |
# Setting the Adam optimizer with hyper-param
|
|
|
164 |
self.criterion = nn.NLLLoss(ignore_index=0)
|
165 |
|
166 |
self.log_freq = log_freq
|
167 |
+
<<<<<<< HEAD
|
168 |
+
self.log_folder_path = log_folder_path
|
169 |
+
# self.workspace_name = workspace_name
|
170 |
+
self.save_model = False
|
171 |
+
# self.code = code
|
172 |
+
self.avg_loss = 10000
|
173 |
+
for fi in ['train', 'val', 'test']:
|
174 |
+
f = open(self.log_folder_path+f"/log_{fi}_pretrained.txt", 'w')
|
175 |
+
f.close()
|
176 |
+
self.start_time = time.time()
|
177 |
+
|
178 |
+
=======
|
179 |
self.same_student_prediction = same_student_prediction
|
180 |
self.workspace_name = workspace_name
|
181 |
self.save_model = False
|
182 |
self.avg_loss = 10000
|
183 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
184 |
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
185 |
|
186 |
def train(self, epoch):
|
187 |
self.iteration(epoch, self.train_data)
|
188 |
|
189 |
+
<<<<<<< HEAD
|
190 |
+
def val(self, epoch):
|
191 |
+
if epoch == 0:
|
192 |
+
self.avg_loss = 10000
|
193 |
+
self.iteration(epoch, self.val_data, phase="val")
|
194 |
+
|
195 |
+
def test(self, epoch):
|
196 |
+
self.iteration(epoch, self.test_data, phase="test")
|
197 |
+
|
198 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
199 |
+
=======
|
200 |
def test(self, epoch):
|
201 |
self.iteration(epoch, self.test_data, train=False)
|
202 |
|
203 |
def iteration(self, epoch, data_loader, train=True):
|
204 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
205 |
"""
|
206 |
loop over the data_loader for training or testing
|
207 |
if on train status, backward operation is activated
|
|
|
212 |
:param train: boolean value of is train or test
|
213 |
:return: None
|
214 |
"""
|
215 |
+
<<<<<<< HEAD
|
216 |
+
|
217 |
+
# self.log_file = f"{self.workspace_name}/logs/{self.code}/log_{phase}_pretrained.txt"
|
218 |
+
# bert_hidden_representations = [] can be used
|
219 |
+
# if epoch == 0:
|
220 |
+
# f = open(self.log_file, 'w')
|
221 |
+
# f.close()
|
222 |
+
|
223 |
+
# Progress bar
|
224 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
225 |
+
desc="EP_%s:%d" % (phase, epoch),
|
226 |
+
total=len(data_loader),
|
227 |
+
bar_format="{l_bar}{r_bar}")
|
228 |
+
|
229 |
+
total_correct = 0
|
230 |
+
total_element = 0
|
231 |
+
avg_loss = 0.0
|
232 |
+
|
233 |
+
if phase == "train":
|
234 |
+
self.model.train()
|
235 |
+
else:
|
236 |
+
self.model.eval()
|
237 |
+
with open(self.log_folder_path+f"/log_{phase}_pretrained.txt", 'a') as f:
|
238 |
+
=======
|
239 |
str_code = "train" if train else "test"
|
240 |
code = "masked_prediction" if self.same_student_prediction else "masked"
|
241 |
|
|
|
262 |
|
263 |
avg_loss = 0.0
|
264 |
with open(self.log_file, 'a') as f:
|
265 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
266 |
sys.stdout = f
|
267 |
for i, data in data_iter:
|
268 |
# 0. batch_data will be sent into the device(GPU or cpu)
|
269 |
data = {key: value.to(self.device) for key, value in data.items()}
|
270 |
+
<<<<<<< HEAD
|
271 |
+
|
272 |
+
# 1. forward masked_sm model
|
273 |
+
# mask_sm_output is log-probabilities output
|
274 |
+
mask_sm_output, bert_hidden_rep = self.model.forward(data["bert_input"], data["segment_label"])
|
275 |
+
|
276 |
+
# 2. NLLLoss of predicting masked token word
|
277 |
+
loss = self.criterion(mask_sm_output.transpose(1, 2), data["bert_label"])
|
278 |
+
if torch.cuda.device_count() > 1:
|
279 |
+
loss = loss.mean()
|
280 |
+
|
281 |
+
# 3. backward and optimization only in train
|
282 |
+
if phase == "train":
|
283 |
+
=======
|
284 |
|
285 |
# 1. forward the next_sentence_prediction and masked_lm model
|
286 |
# next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])
|
|
|
306 |
|
307 |
# 3. backward and optimization only in train
|
308 |
if train:
|
309 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
310 |
self.optim_schedule.zero_grad()
|
311 |
loss.backward()
|
312 |
self.optim_schedule.step_and_update_lr()
|
313 |
|
314 |
+
<<<<<<< HEAD
|
315 |
+
# tokens with highest log-probabilities creates a predicted sequence
|
316 |
+
pred_tokens = torch.argmax(mask_sm_output, dim=-1)
|
317 |
+
mask_correct = (data["bert_label"] == pred_tokens) & data["masked_pos"]
|
318 |
+
|
319 |
+
total_correct += mask_correct.sum().item()
|
320 |
+
total_element += data["masked_pos"].sum().item()
|
321 |
+
avg_loss +=loss.item()
|
322 |
+
|
323 |
+
torch.cuda.empty_cache()
|
324 |
+
|
325 |
+
post_fix = {
|
326 |
+
"epoch": epoch,
|
327 |
+
"iter": i,
|
328 |
+
"avg_loss": avg_loss / (i + 1),
|
329 |
+
"avg_acc_mask": (total_correct / total_element * 100) if total_element != 0 else 0,
|
330 |
+
"loss": loss.item()
|
331 |
+
}
|
332 |
+
if i % self.log_freq == 0:
|
333 |
+
data_iter.write(str(post_fix))
|
334 |
+
|
335 |
+
end_time = time.time()
|
336 |
+
final_msg = {
|
337 |
+
"epoch": f"EP{epoch}_{phase}",
|
338 |
+
"avg_loss": avg_loss / len(data_iter),
|
339 |
+
"total_masked_acc": (total_correct / total_element * 100) if total_element != 0 else 0,
|
340 |
+
"time_taken_from_start": end_time - self.start_time
|
341 |
+
}
|
342 |
+
print(final_msg)
|
343 |
+
f.close()
|
344 |
+
sys.stdout = sys.__stdout__
|
345 |
+
|
346 |
+
if phase == "val":
|
347 |
+
self.save_model = False
|
348 |
+
if self.avg_loss > (avg_loss / len(data_iter)):
|
349 |
+
self.save_model = True
|
350 |
+
self.avg_loss = (avg_loss / len(data_iter))
|
351 |
+
=======
|
352 |
|
353 |
non_zero_mask = (data["bert_label"] != 0).float()
|
354 |
predictions = torch.argmax(mask_lm_output, dim=-1)
|
|
|
410 |
# pickle.dump(bert_hidden_representations, open(f"embeddings/{code}/{str_code}_embeddings_{epoch}.pkl","wb"))
|
411 |
|
412 |
|
413 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
414 |
|
415 |
def save(self, epoch, file_path="output/bert_trained.model"):
|
416 |
"""
|
|
|
432 |
def __init__(self, bert: BERT, vocab_size: int,
|
433 |
train_dataloader: DataLoader, test_dataloader: DataLoader = None,
|
434 |
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
435 |
+
<<<<<<< HEAD
|
436 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
437 |
+
num_labels=2, log_folder_path: str = None):
|
438 |
+
=======
|
439 |
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None, num_labels=2):
|
440 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
441 |
"""
|
442 |
:param bert: BERT model which you want to train
|
443 |
:param vocab_size: total word vocab size
|
|
|
453 |
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
454 |
cuda_condition = torch.cuda.is_available() and with_cuda
|
455 |
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
456 |
+
<<<<<<< HEAD
|
457 |
+
print(cuda_condition, " Device used = ", self.device)
|
458 |
+
|
459 |
+
available_gpus = list(range(torch.cuda.device_count()))
|
460 |
+
|
461 |
+
# This BERT model will be saved every epoch
|
462 |
+
self.bert = bert
|
463 |
+
for param in self.bert.parameters():
|
464 |
+
param.requires_grad = False
|
465 |
+
# Initialize the BERT Language Model, with BERT model
|
466 |
+
# self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
467 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8).to(self.device)
|
468 |
+
self.model = BERTForClassificationWithFeats(self.bert, num_labels, 17).to(self.device)
|
469 |
+
|
470 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 1).to(self.device)
|
471 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
472 |
+
if with_cuda and torch.cuda.device_count() > 1:
|
473 |
+
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
474 |
+
self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
475 |
+
|
476 |
+
# Setting the train, validation and test data loader
|
477 |
+
self.train_data = train_dataloader
|
478 |
+
# self.val_data = val_dataloader
|
479 |
+
self.test_data = test_dataloader
|
480 |
+
|
481 |
+
# self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
482 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
483 |
+
self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
|
484 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
485 |
+
self.criterion = nn.CrossEntropyLoss()
|
486 |
+
|
487 |
+
# if num_labels == 1:
|
488 |
+
# self.criterion = nn.MSELoss()
|
489 |
+
# elif num_labels == 2:
|
490 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
491 |
+
# # self.criterion = nn.CrossEntropyLoss()
|
492 |
+
# elif num_labels > 2:
|
493 |
+
# self.criterion = nn.CrossEntropyLoss()
|
494 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
495 |
+
|
496 |
+
|
497 |
+
self.log_freq = log_freq
|
498 |
+
self.log_folder_path = log_folder_path
|
499 |
+
# self.workspace_name = workspace_name
|
500 |
+
# self.finetune_task = finetune_task
|
501 |
+
self.save_model = False
|
502 |
+
self.avg_loss = 10000
|
503 |
+
self.start_time = time.time()
|
504 |
+
# self.probability_list = []
|
505 |
+
for fi in ['train', 'test']: #'val',
|
506 |
+
f = open(self.log_folder_path+f"/log_{fi}_finetuned.txt", 'w')
|
507 |
+
f.close()
|
508 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
509 |
+
|
510 |
+
def train(self, epoch):
|
511 |
+
self.iteration(epoch, self.train_data)
|
512 |
+
|
513 |
+
# def val(self, epoch):
|
514 |
+
# self.iteration(epoch, self.val_data, phase="val")
|
515 |
+
|
516 |
+
def test(self, epoch):
|
517 |
+
if epoch == 0:
|
518 |
+
self.avg_loss = 10000
|
519 |
+
self.iteration(epoch, self.test_data, phase="test")
|
520 |
+
|
521 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
522 |
+
"""
|
523 |
+
loop over the data_loader for training or testing
|
524 |
+
if on train status, backward operation is activated
|
525 |
+
and also auto save the model every peoch
|
526 |
+
|
527 |
+
:param epoch: current epoch index
|
528 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
529 |
+
:param train: boolean value of is train or test
|
530 |
+
:return: None
|
531 |
+
"""
|
532 |
+
|
533 |
+
# Setting the tqdm progress bar
|
534 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
535 |
+
desc="EP_%s:%d" % (phase, epoch),
|
536 |
+
total=len(data_loader),
|
537 |
+
bar_format="{l_bar}{r_bar}")
|
538 |
+
|
539 |
+
avg_loss = 0.0
|
540 |
+
total_correct = 0
|
541 |
+
total_element = 0
|
542 |
+
plabels = []
|
543 |
+
tlabels = []
|
544 |
+
probabs = []
|
545 |
+
|
546 |
+
if phase == "train":
|
547 |
+
self.model.train()
|
548 |
+
else:
|
549 |
+
self.model.eval()
|
550 |
+
# self.probability_list = []
|
551 |
+
|
552 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned.txt", 'a') as f:
|
553 |
+
sys.stdout = f
|
554 |
+
for i, data in data_iter:
|
555 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
556 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
557 |
+
if phase == "train":
|
558 |
+
logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
559 |
+
else:
|
560 |
+
with torch.no_grad():
|
561 |
+
logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
562 |
+
|
563 |
+
loss = self.criterion(logits, data["label"])
|
564 |
+
if torch.cuda.device_count() > 1:
|
565 |
+
loss = loss.mean()
|
566 |
+
|
567 |
+
# 3. backward and optimization only in train
|
568 |
+
if phase == "train":
|
569 |
+
self.optim_schedule.zero_grad()
|
570 |
+
loss.backward()
|
571 |
+
self.optim_schedule.step_and_update_lr()
|
572 |
+
|
573 |
+
# prediction accuracy
|
574 |
+
probs = nn.Softmax(dim=-1)(logits) # Probabilities
|
575 |
+
probabs.extend(probs.detach().cpu().numpy().tolist())
|
576 |
+
predicted_labels = torch.argmax(probs, dim=-1) #correct
|
577 |
+
# self.probability_list.append(probs)
|
578 |
+
# true_labels = torch.argmax(data["label"], dim=-1)
|
579 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
580 |
+
tlabels.extend(data['label'].cpu().numpy())
|
581 |
+
|
582 |
+
# Compare predicted labels to true labels and calculate accuracy
|
583 |
+
correct = (data['label'] == predicted_labels).sum().item()
|
584 |
+
|
585 |
+
avg_loss += loss.item()
|
586 |
+
total_correct += correct
|
587 |
+
# total_element += true_labels.nelement()
|
588 |
+
total_element += data["label"].nelement()
|
589 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
590 |
+
|
591 |
+
post_fix = {
|
592 |
+
"epoch": epoch,
|
593 |
+
"iter": i,
|
594 |
+
"avg_loss": avg_loss / (i + 1),
|
595 |
+
"avg_acc": total_correct / total_element * 100 if total_element != 0 else 0,
|
596 |
+
"loss": loss.item()
|
597 |
+
}
|
598 |
+
if i % self.log_freq == 0:
|
599 |
+
data_iter.write(str(post_fix))
|
600 |
+
|
601 |
+
precisions = precision_score(tlabels, plabels, average="weighted", zero_division=0)
|
602 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
603 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
604 |
+
cmatrix = confusion_matrix(tlabels, plabels)
|
605 |
+
end_time = time.time()
|
606 |
+
final_msg = {
|
607 |
+
"epoch": f"EP{epoch}_{phase}",
|
608 |
+
"avg_loss": avg_loss / len(data_iter),
|
609 |
+
"total_acc": total_correct * 100.0 / total_element,
|
610 |
+
"precisions": precisions,
|
611 |
+
"recalls": recalls,
|
612 |
+
"f1_scores": f1_scores,
|
613 |
+
# "confusion_matrix": f"{cmatrix}",
|
614 |
+
# "true_labels": f"{tlabels}",
|
615 |
+
# "predicted_labels": f"{plabels}",
|
616 |
+
"time_taken_from_start": end_time - self.start_time
|
617 |
+
}
|
618 |
+
print(final_msg)
|
619 |
+
f.close()
|
620 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
|
621 |
+
sys.stdout = f1
|
622 |
+
final_msg = {
|
623 |
+
"epoch": f"EP{epoch}_{phase}",
|
624 |
+
"confusion_matrix": f"{cmatrix}",
|
625 |
+
"true_labels": f"{tlabels if epoch == 0 else ''}",
|
626 |
+
"predicted_labels": f"{plabels}",
|
627 |
+
"probabilities": f"{probabs}",
|
628 |
+
"time_taken_from_start": end_time - self.start_time
|
629 |
+
}
|
630 |
+
print(final_msg)
|
631 |
+
f1.close()
|
632 |
+
sys.stdout = sys.__stdout__
|
633 |
+
sys.stdout = sys.__stdout__
|
634 |
+
|
635 |
+
if phase == "test":
|
636 |
+
self.save_model = False
|
637 |
+
if self.avg_loss > (avg_loss / len(data_iter)):
|
638 |
+
self.save_model = True
|
639 |
+
self.avg_loss = (avg_loss / len(data_iter))
|
640 |
+
|
641 |
+
def iteration_1(self, epoch_idx, data):
|
642 |
+
try:
|
643 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
644 |
+
logits = self.model(data['input_ids'], data['segment_label'])
|
645 |
+
# Ensure logits is a tensor, not a tuple
|
646 |
+
loss_fct = nn.CrossEntropyLoss()
|
647 |
+
loss = loss_fct(logits, data['labels'])
|
648 |
+
|
649 |
+
# Backpropagation and optimization
|
650 |
+
self.optim.zero_grad()
|
651 |
+
loss.backward()
|
652 |
+
self.optim.step()
|
653 |
+
|
654 |
+
if self.log_freq > 0 and epoch_idx % self.log_freq == 0:
|
655 |
+
print(f"Epoch {epoch_idx}: Loss = {loss.item()}")
|
656 |
+
|
657 |
+
return loss
|
658 |
+
|
659 |
+
except Exception as e:
|
660 |
+
print(f"Error during iteration: {e}")
|
661 |
+
raise
|
662 |
+
|
663 |
+
|
664 |
+
def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
|
665 |
+
"""
|
666 |
+
Saving the current BERT model on file_path
|
667 |
+
|
668 |
+
:param epoch: current epoch number
|
669 |
+
:param file_path: model output path which gonna be file_path+"ep%d" % epoch
|
670 |
+
:return: final_output_path
|
671 |
+
"""
|
672 |
+
output_path = file_path + ".ep%d" % epoch
|
673 |
+
torch.save(self.model.cpu(), output_path)
|
674 |
+
self.model.to(self.device)
|
675 |
+
print("EP:%d Model Saved on:" % epoch, output_path)
|
676 |
+
return output_path
|
677 |
+
|
678 |
+
class BERTFineTuneTrainer1:
|
679 |
+
|
680 |
+
def __init__(self, bert: BERT, vocab_size: int,
|
681 |
+
train_dataloader: DataLoader, test_dataloader: DataLoader = None,
|
682 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
683 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
684 |
+
num_labels=2, log_folder_path: str = None):
|
685 |
+
"""
|
686 |
+
:param bert: BERT model which you want to train
|
687 |
+
:param vocab_size: total word vocab size
|
688 |
+
:param train_dataloader: train dataset data loader
|
689 |
+
:param test_dataloader: test dataset data loader [can be None]
|
690 |
+
:param lr: learning rate of optimizer
|
691 |
+
:param betas: Adam optimizer betas
|
692 |
+
:param weight_decay: Adam optimizer weight decay param
|
693 |
+
:param with_cuda: traning with cuda
|
694 |
+
:param log_freq: logging frequency of the batch iteration
|
695 |
+
"""
|
696 |
+
|
697 |
+
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
698 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
699 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
700 |
+
print(cuda_condition, " Device used = ", self.device)
|
701 |
+
|
702 |
+
available_gpus = list(range(torch.cuda.device_count()))
|
703 |
+
|
704 |
+
# This BERT model will be saved every epoch
|
705 |
+
self.bert = bert
|
706 |
+
for param in self.bert.parameters():
|
707 |
+
param.requires_grad = False
|
708 |
+
# Initialize the BERT Language Model, with BERT model
|
709 |
+
self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
710 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8).to(self.device)
|
711 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8*2).to(self.device)
|
712 |
+
|
713 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 1).to(self.device)
|
714 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
715 |
+
if with_cuda and torch.cuda.device_count() > 1:
|
716 |
+
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
717 |
+
self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
718 |
+
|
719 |
+
# Setting the train, validation and test data loader
|
720 |
+
self.train_data = train_dataloader
|
721 |
+
# self.val_data = val_dataloader
|
722 |
+
self.test_data = test_dataloader
|
723 |
+
|
724 |
+
# self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
725 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
726 |
+
self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
|
727 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
728 |
+
self.criterion = nn.CrossEntropyLoss()
|
729 |
+
|
730 |
+
# if num_labels == 1:
|
731 |
+
# self.criterion = nn.MSELoss()
|
732 |
+
# elif num_labels == 2:
|
733 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
734 |
+
# # self.criterion = nn.CrossEntropyLoss()
|
735 |
+
# elif num_labels > 2:
|
736 |
+
# self.criterion = nn.CrossEntropyLoss()
|
737 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
738 |
+
|
739 |
+
|
740 |
+
self.log_freq = log_freq
|
741 |
+
self.log_folder_path = log_folder_path
|
742 |
+
# self.workspace_name = workspace_name
|
743 |
+
# self.finetune_task = finetune_task
|
744 |
+
self.save_model = False
|
745 |
+
self.avg_loss = 10000
|
746 |
+
self.start_time = time.time()
|
747 |
+
# self.probability_list = []
|
748 |
+
for fi in ['train', 'test']: #'val',
|
749 |
+
f = open(self.log_folder_path+f"/log_{fi}_finetuned.txt", 'w')
|
750 |
+
f.close()
|
751 |
+
=======
|
752 |
print("Device used = ", self.device)
|
753 |
|
754 |
# This BERT model will be saved every epoch
|
|
|
783 |
self.workspace_name = workspace_name
|
784 |
self.save_model = False
|
785 |
self.avg_loss = 10000
|
786 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
787 |
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
788 |
|
789 |
def train(self, epoch):
|
790 |
self.iteration(epoch, self.train_data)
|
791 |
|
792 |
+
<<<<<<< HEAD
|
793 |
+
# def val(self, epoch):
|
794 |
+
# self.iteration(epoch, self.val_data, phase="val")
|
795 |
+
|
796 |
+
def test(self, epoch):
|
797 |
+
if epoch == 0:
|
798 |
+
self.avg_loss = 10000
|
799 |
+
self.iteration(epoch, self.test_data, phase="test")
|
800 |
+
|
801 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
802 |
+
=======
|
803 |
def test(self, epoch):
|
804 |
self.iteration(epoch, self.test_data, train=False)
|
805 |
|
806 |
def iteration(self, epoch, data_loader, train=True):
|
807 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
808 |
"""
|
809 |
loop over the data_loader for training or testing
|
810 |
if on train status, backward operation is activated
|
|
|
815 |
:param train: boolean value of is train or test
|
816 |
:return: None
|
817 |
"""
|
818 |
+
<<<<<<< HEAD
|
819 |
+
|
820 |
+
# Setting the tqdm progress bar
|
821 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
822 |
+
desc="EP_%s:%d" % (phase, epoch),
|
823 |
+
=======
|
824 |
str_code = "train" if train else "test"
|
825 |
|
826 |
self.log_file = f"{self.workspace_name}/logs/masked/log_{str_code}_FS_finetuned.txt"
|
|
|
834 |
# Setting the tqdm progress bar
|
835 |
data_iter = tqdm.tqdm(enumerate(data_loader),
|
836 |
desc="EP_%s:%d" % (str_code, epoch),
|
837 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
838 |
total=len(data_loader),
|
839 |
bar_format="{l_bar}{r_bar}")
|
840 |
|
|
|
843 |
total_element = 0
|
844 |
plabels = []
|
845 |
tlabels = []
|
846 |
+
<<<<<<< HEAD
|
847 |
+
probabs = []
|
848 |
+
|
849 |
+
if phase == "train":
|
850 |
+
self.model.train()
|
851 |
+
else:
|
852 |
+
self.model.eval()
|
853 |
+
# self.probability_list = []
|
854 |
+
|
855 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned.txt", 'a') as f:
|
856 |
+
sys.stdout = f
|
857 |
+
for i, data in data_iter:
|
858 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
859 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
860 |
+
if phase == "train":
|
861 |
+
logits = self.model.forward(data["input"], data["segment_label"])#, data["feat"])
|
862 |
+
else:
|
863 |
+
with torch.no_grad():
|
864 |
+
logits = self.model.forward(data["input"], data["segment_label"])#, data["feat"])
|
865 |
+
|
866 |
+
loss = self.criterion(logits, data["label"])
|
867 |
+
=======
|
868 |
eval_accurate_nb = 0
|
869 |
nb_eval_examples = 0
|
870 |
logits_list = []
|
|
|
895 |
progress_loss = self.criterion(logits, data["progress_status"])
|
896 |
loss = progress_loss
|
897 |
|
898 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
899 |
if torch.cuda.device_count() > 1:
|
900 |
loss = loss.mean()
|
901 |
|
902 |
# 3. backward and optimization only in train
|
903 |
+
<<<<<<< HEAD
|
904 |
+
if phase == "train":
|
905 |
+
self.optim_schedule.zero_grad()
|
906 |
+
loss.backward()
|
907 |
+
self.optim_schedule.step_and_update_lr()
|
908 |
+
|
909 |
+
# prediction accuracy
|
910 |
+
probs = nn.Softmax(dim=-1)(logits) # Probabilities
|
911 |
+
probabs.extend(probs.detach().cpu().numpy().tolist())
|
912 |
+
predicted_labels = torch.argmax(probs, dim=-1) #correct
|
913 |
+
# self.probability_list.append(probs)
|
914 |
+
# true_labels = torch.argmax(data["label"], dim=-1)
|
915 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
916 |
+
tlabels.extend(data['label'].cpu().numpy())
|
917 |
+
|
918 |
+
# Compare predicted labels to true labels and calculate accuracy
|
919 |
+
correct = (data['label'] == predicted_labels).sum().item()
|
920 |
+
|
921 |
+
avg_loss += loss.item()
|
922 |
+
total_correct += correct
|
923 |
+
# total_element += true_labels.nelement()
|
924 |
+
total_element += data["label"].nelement()
|
925 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
926 |
+
|
927 |
+
post_fix = {
|
928 |
+
"epoch": epoch,
|
929 |
+
"iter": i,
|
930 |
+
"avg_loss": avg_loss / (i + 1),
|
931 |
+
"avg_acc": total_correct / total_element * 100 if total_element != 0 else 0,
|
932 |
+
"loss": loss.item()
|
933 |
+
}
|
934 |
+
if i % self.log_freq == 0:
|
935 |
+
data_iter.write(str(post_fix))
|
936 |
+
|
937 |
+
precisions = precision_score(tlabels, plabels, average="weighted", zero_division=0)
|
938 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
939 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
940 |
+
cmatrix = confusion_matrix(tlabels, plabels)
|
941 |
+
end_time = time.time()
|
942 |
+
final_msg = {
|
943 |
+
"epoch": f"EP{epoch}_{phase}",
|
944 |
+
"avg_loss": avg_loss / len(data_iter),
|
945 |
+
"total_acc": total_correct * 100.0 / total_element,
|
946 |
+
"precisions": precisions,
|
947 |
+
"recalls": recalls,
|
948 |
+
"f1_scores": f1_scores,
|
949 |
+
# "confusion_matrix": f"{cmatrix}",
|
950 |
+
# "true_labels": f"{tlabels}",
|
951 |
+
# "predicted_labels": f"{plabels}",
|
952 |
+
"time_taken_from_start": end_time - self.start_time
|
953 |
+
}
|
954 |
+
print(final_msg)
|
955 |
+
f.close()
|
956 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
|
957 |
+
sys.stdout = f1
|
958 |
+
final_msg = {
|
959 |
+
"epoch": f"EP{epoch}_{phase}",
|
960 |
+
"confusion_matrix": f"{cmatrix}",
|
961 |
+
"true_labels": f"{tlabels if epoch == 0 else ''}",
|
962 |
+
"predicted_labels": f"{plabels}",
|
963 |
+
"probabilities": f"{probabs}",
|
964 |
+
"time_taken_from_start": end_time - self.start_time
|
965 |
+
}
|
966 |
+
print(final_msg)
|
967 |
+
f1.close()
|
968 |
+
sys.stdout = sys.__stdout__
|
969 |
+
sys.stdout = sys.__stdout__
|
970 |
+
|
971 |
+
if phase == "test":
|
972 |
+
=======
|
973 |
if train:
|
974 |
self.optim.zero_grad()
|
975 |
loss.backward()
|
|
|
1065 |
f.close()
|
1066 |
sys.stdout = sys.__stdout__
|
1067 |
if train:
|
1068 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
1069 |
self.save_model = False
|
1070 |
if self.avg_loss > (avg_loss / len(data_iter)):
|
1071 |
self.save_model = True
|
1072 |
self.avg_loss = (avg_loss / len(data_iter))
|
1073 |
+
<<<<<<< HEAD
|
1074 |
+
|
1075 |
+
def iteration_1(self, epoch_idx, data):
|
1076 |
+
try:
|
1077 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
1078 |
+
logits = self.model(data['input_ids'], data['segment_label'])
|
1079 |
+
# Ensure logits is a tensor, not a tuple
|
1080 |
+
loss_fct = nn.CrossEntropyLoss()
|
1081 |
+
loss = loss_fct(logits, data['labels'])
|
1082 |
+
|
1083 |
+
# Backpropagation and optimization
|
1084 |
+
self.optim.zero_grad()
|
1085 |
+
loss.backward()
|
1086 |
+
self.optim.step()
|
1087 |
+
|
1088 |
+
if self.log_freq > 0 and epoch_idx % self.log_freq == 0:
|
1089 |
+
print(f"Epoch {epoch_idx}: Loss = {loss.item()}")
|
1090 |
+
|
1091 |
+
return loss
|
1092 |
+
|
1093 |
+
except Exception as e:
|
1094 |
+
print(f"Error during iteration: {e}")
|
1095 |
+
raise
|
1096 |
+
|
1097 |
+
=======
|
1098 |
|
1099 |
# plt_test.show()
|
1100 |
# print("EP%d_%s, " % (epoch, str_code))
|
1101 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
1102 |
|
1103 |
def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
|
1104 |
"""
|
|
|
1113 |
self.model.to(self.device)
|
1114 |
print("EP:%d Model Saved on:" % epoch, output_path)
|
1115 |
return output_path
|
1116 |
+
<<<<<<< HEAD
|
1117 |
+
|
1118 |
+
|
1119 |
+
class BERTAttention:
|
1120 |
+
def __init__(self, bert: BERT, vocab_obj, train_dataloader: DataLoader, workspace_name=None, code=None, finetune_task=None, with_cuda=True):
|
1121 |
+
|
1122 |
+
# available_gpus = list(range(torch.cuda.device_count()))
|
1123 |
+
|
1124 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
1125 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
1126 |
+
print(with_cuda, cuda_condition, " Device used = ", self.device)
|
1127 |
+
self.bert = bert.to(self.device)
|
1128 |
+
|
1129 |
+
# if with_cuda and torch.cuda.device_count() > 1:
|
1130 |
+
# print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
1131 |
+
# self.bert = nn.DataParallel(self.bert, device_ids=available_gpus)
|
1132 |
+
|
1133 |
+
self.train_dataloader = train_dataloader
|
1134 |
+
self.workspace_name = workspace_name
|
1135 |
+
self.code = code
|
1136 |
+
self.finetune_task = finetune_task
|
1137 |
+
self.vocab_obj = vocab_obj
|
1138 |
+
|
1139 |
+
def getAttention(self):
|
1140 |
+
# self.log_file = f"{self.workspace_name}/logs/{self.code}/log_attention.txt"
|
1141 |
+
|
1142 |
+
|
1143 |
+
labels = ['PercentChange', 'NumeratorQuantity2', 'NumeratorQuantity1', 'DenominatorQuantity1',
|
1144 |
+
'OptionalTask_1', 'EquationAnswer', 'NumeratorFactor', 'DenominatorFactor',
|
1145 |
+
'OptionalTask_2', 'FirstRow1:1', 'FirstRow1:2', 'FirstRow2:1', 'FirstRow2:2', 'SecondRow',
|
1146 |
+
'ThirdRow', 'FinalAnswer','FinalAnswerDirection']
|
1147 |
+
df_all = pd.DataFrame(0.0, index=labels, columns=labels)
|
1148 |
+
# Setting the tqdm progress bar
|
1149 |
+
data_iter = tqdm.tqdm(enumerate(self.train_dataloader),
|
1150 |
+
desc="attention",
|
1151 |
+
total=len(self.train_dataloader),
|
1152 |
+
bar_format="{l_bar}{r_bar}")
|
1153 |
+
count = 0
|
1154 |
+
for i, data in data_iter:
|
1155 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
1156 |
+
a = self.bert.forward(data["bert_input"], data["segment_label"])
|
1157 |
+
non_zero = np.sum(data["segment_label"].cpu().detach().numpy())
|
1158 |
+
|
1159 |
+
# Last Transformer Layer
|
1160 |
+
last_layer = self.bert.attention_values[-1].transpose(1,0,2,3)
|
1161 |
+
# print(last_layer.shape)
|
1162 |
+
head, d_model, s, s = last_layer.shape
|
1163 |
+
|
1164 |
+
for d in range(d_model):
|
1165 |
+
seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
|
1166 |
+
# df_all = pd.DataFrame(0.0, index=seq_labels, columns=seq_labels)
|
1167 |
+
indices_to_choose = defaultdict(int)
|
1168 |
+
|
1169 |
+
for k,s in enumerate(seq_labels):
|
1170 |
+
if s in labels:
|
1171 |
+
indices_to_choose[s] = k
|
1172 |
+
indices_chosen = list(indices_to_choose.values())
|
1173 |
+
selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
|
1174 |
+
# print(len(seq_labels), len(selected_seq_labels))
|
1175 |
+
for h in range(head):
|
1176 |
+
# fig, ax = plt.subplots(figsize=(12, 12))
|
1177 |
+
# seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])#[1:non_zero-1]
|
1178 |
+
# seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
|
1179 |
+
# indices_to_choose = defaultdict(int)
|
1180 |
+
|
1181 |
+
# for k,s in enumerate(seq_labels):
|
1182 |
+
# if s in labels:
|
1183 |
+
# indices_to_choose[s] = k
|
1184 |
+
# indices_chosen = list(indices_to_choose.values())
|
1185 |
+
# selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
|
1186 |
+
# print(f"Chosen index: {seq_labels, indices_to_choose, indices_chosen, selected_seq_labels}")
|
1187 |
+
|
1188 |
+
df_cm = pd.DataFrame(last_layer[h][d][indices_chosen,:][:,indices_chosen], index = selected_seq_labels, columns = selected_seq_labels)
|
1189 |
+
df_all = df_all.add(df_cm, fill_value=0)
|
1190 |
+
count += 1
|
1191 |
+
|
1192 |
+
# df_cm = pd.DataFrame(last_layer[h][d][1:non_zero-1,:][:,1:non_zero-1], index=seq_labels, columns=seq_labels)
|
1193 |
+
# df_all = df_all.add(df_cm, fill_value=0)
|
1194 |
+
|
1195 |
+
# df_all = df_all.reindex(index=seq_labels, columns=seq_labels)
|
1196 |
+
# sns.heatmap(df_all, annot=False)
|
1197 |
+
# plt.title("Attentions") #Probabilities
|
1198 |
+
# plt.xlabel("Steps")
|
1199 |
+
# plt.ylabel("Steps")
|
1200 |
+
# plt.grid(True)
|
1201 |
+
# plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
|
1202 |
+
# plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores_over_[{h}]_head_n_data[{d}].png", bbox_inches='tight')
|
1203 |
+
# plt.show()
|
1204 |
+
# plt.close()
|
1205 |
+
|
1206 |
+
|
1207 |
+
|
1208 |
+
print(f"Count of total : {count, head * self.train_dataloader.dataset.len}")
|
1209 |
+
df_all = df_all.div(count) # head * self.train_dataloader.dataset.len
|
1210 |
+
df_all = df_all.reindex(index=labels, columns=labels)
|
1211 |
+
sns.heatmap(df_all, annot=False)
|
1212 |
+
plt.title("Attentions") #Probabilities
|
1213 |
+
plt.xlabel("Steps")
|
1214 |
+
plt.ylabel("Steps")
|
1215 |
+
plt.grid(True)
|
1216 |
+
plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
|
1217 |
+
plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores.png", bbox_inches='tight')
|
1218 |
+
plt.show()
|
1219 |
+
plt.close()
|
1220 |
+
|
1221 |
+
|
1222 |
+
|
1223 |
+
|
1224 |
+
=======
|
1225 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
src/reference_code/bert_reference_code.py
ADDED
@@ -0,0 +1,1622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch BERT model. """
|
17 |
+
|
18 |
+
|
19 |
+
import logging
|
20 |
+
import math
|
21 |
+
import os
|
22 |
+
import warnings
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.utils.checkpoint
|
26 |
+
from torch import nn
|
27 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
28 |
+
|
29 |
+
from .activations import gelu, gelu_new, swish
|
30 |
+
from .configuration_bert import BertConfig
|
31 |
+
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
32 |
+
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
33 |
+
|
34 |
+
|
35 |
+
logger = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
38 |
+
|
39 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
40 |
+
"bert-base-uncased",
|
41 |
+
"bert-large-uncased",
|
42 |
+
"bert-base-cased",
|
43 |
+
"bert-large-cased",
|
44 |
+
"bert-base-multilingual-uncased",
|
45 |
+
"bert-base-multilingual-cased",
|
46 |
+
"bert-base-chinese",
|
47 |
+
"bert-base-german-cased",
|
48 |
+
"bert-large-uncased-whole-word-masking",
|
49 |
+
"bert-large-cased-whole-word-masking",
|
50 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad",
|
51 |
+
"bert-large-cased-whole-word-masking-finetuned-squad",
|
52 |
+
"bert-base-cased-finetuned-mrpc",
|
53 |
+
"bert-base-german-dbmdz-cased",
|
54 |
+
"bert-base-german-dbmdz-uncased",
|
55 |
+
"cl-tohoku/bert-base-japanese",
|
56 |
+
"cl-tohoku/bert-base-japanese-whole-word-masking",
|
57 |
+
"cl-tohoku/bert-base-japanese-char",
|
58 |
+
"cl-tohoku/bert-base-japanese-char-whole-word-masking",
|
59 |
+
"TurkuNLP/bert-base-finnish-cased-v1",
|
60 |
+
"TurkuNLP/bert-base-finnish-uncased-v1",
|
61 |
+
"wietsedv/bert-base-dutch-cased",
|
62 |
+
# See all BERT models at https://huggingface.co/models?filter=bert
|
63 |
+
]
|
64 |
+
|
65 |
+
|
66 |
+
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
67 |
+
""" Load tf checkpoints in a pytorch model.
|
68 |
+
"""
|
69 |
+
try:
|
70 |
+
import re
|
71 |
+
import numpy as np
|
72 |
+
import tensorflow as tf
|
73 |
+
except ImportError:
|
74 |
+
logger.error(
|
75 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
76 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
77 |
+
)
|
78 |
+
raise
|
79 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
80 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
81 |
+
# Load weights from TF model
|
82 |
+
init_vars = tf.train.list_variables(tf_path)
|
83 |
+
names = []
|
84 |
+
arrays = []
|
85 |
+
for name, shape in init_vars:
|
86 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
87 |
+
array = tf.train.load_variable(tf_path, name)
|
88 |
+
names.append(name)
|
89 |
+
arrays.append(array)
|
90 |
+
|
91 |
+
for name, array in zip(names, arrays):
|
92 |
+
name = name.split("/")
|
93 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
94 |
+
# which are not required for using pretrained model
|
95 |
+
if any(
|
96 |
+
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
97 |
+
for n in name
|
98 |
+
):
|
99 |
+
logger.info("Skipping {}".format("/".join(name)))
|
100 |
+
continue
|
101 |
+
pointer = model
|
102 |
+
for m_name in name:
|
103 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
104 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
105 |
+
else:
|
106 |
+
scope_names = [m_name]
|
107 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
108 |
+
pointer = getattr(pointer, "weight")
|
109 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
110 |
+
pointer = getattr(pointer, "bias")
|
111 |
+
elif scope_names[0] == "output_weights":
|
112 |
+
pointer = getattr(pointer, "weight")
|
113 |
+
elif scope_names[0] == "squad":
|
114 |
+
pointer = getattr(pointer, "classifier")
|
115 |
+
else:
|
116 |
+
try:
|
117 |
+
pointer = getattr(pointer, scope_names[0])
|
118 |
+
except AttributeError:
|
119 |
+
logger.info("Skipping {}".format("/".join(name)))
|
120 |
+
continue
|
121 |
+
if len(scope_names) >= 2:
|
122 |
+
num = int(scope_names[1])
|
123 |
+
pointer = pointer[num]
|
124 |
+
if m_name[-11:] == "_embeddings":
|
125 |
+
pointer = getattr(pointer, "weight")
|
126 |
+
elif m_name == "kernel":
|
127 |
+
array = np.transpose(array)
|
128 |
+
try:
|
129 |
+
assert pointer.shape == array.shape
|
130 |
+
except AssertionError as e:
|
131 |
+
e.args += (pointer.shape, array.shape)
|
132 |
+
raise
|
133 |
+
logger.info("Initialize PyTorch weight {}".format(name))
|
134 |
+
pointer.data = torch.from_numpy(array)
|
135 |
+
return model
|
136 |
+
|
137 |
+
|
138 |
+
def mish(x):
|
139 |
+
return x * torch.tanh(nn.functional.softplus(x))
|
140 |
+
|
141 |
+
|
142 |
+
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
|
143 |
+
|
144 |
+
|
145 |
+
BertLayerNorm = torch.nn.LayerNorm
|
146 |
+
|
147 |
+
|
148 |
+
class BertEmbeddings(nn.Module):
|
149 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
150 |
+
"""
|
151 |
+
|
152 |
+
def __init__(self, config):
|
153 |
+
super().__init__()
|
154 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
155 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
156 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
157 |
+
|
158 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
159 |
+
# any TensorFlow checkpoint file
|
160 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
161 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
162 |
+
|
163 |
+
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
164 |
+
if input_ids is not None:
|
165 |
+
input_shape = input_ids.size()
|
166 |
+
else:
|
167 |
+
input_shape = inputs_embeds.size()[:-1]
|
168 |
+
|
169 |
+
seq_length = input_shape[1]
|
170 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
171 |
+
if position_ids is None:
|
172 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
173 |
+
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
174 |
+
if token_type_ids is None:
|
175 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
176 |
+
|
177 |
+
if inputs_embeds is None:
|
178 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
179 |
+
position_embeddings = self.position_embeddings(position_ids)
|
180 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
181 |
+
|
182 |
+
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
183 |
+
embeddings = self.LayerNorm(embeddings)
|
184 |
+
embeddings = self.dropout(embeddings)
|
185 |
+
return embeddings
|
186 |
+
|
187 |
+
|
188 |
+
class BertSelfAttention(nn.Module):
|
189 |
+
def __init__(self, config):
|
190 |
+
super().__init__()
|
191 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
192 |
+
raise ValueError(
|
193 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
194 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
195 |
+
)
|
196 |
+
|
197 |
+
self.num_attention_heads = config.num_attention_heads
|
198 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
199 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
200 |
+
|
201 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
202 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
203 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
204 |
+
|
205 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
206 |
+
|
207 |
+
def transpose_for_scores(self, x):
|
208 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
209 |
+
x = x.view(*new_x_shape)
|
210 |
+
return x.permute(0, 2, 1, 3)
|
211 |
+
|
212 |
+
def forward(
|
213 |
+
self,
|
214 |
+
hidden_states,
|
215 |
+
attention_mask=None,
|
216 |
+
head_mask=None,
|
217 |
+
encoder_hidden_states=None,
|
218 |
+
encoder_attention_mask=None,
|
219 |
+
output_attentions=False,
|
220 |
+
):
|
221 |
+
mixed_query_layer = self.query(hidden_states)
|
222 |
+
|
223 |
+
# If this is instantiated as a cross-attention module, the keys
|
224 |
+
# and values come from an encoder; the attention mask needs to be
|
225 |
+
# such that the encoder's padding tokens are not attended to.
|
226 |
+
if encoder_hidden_states is not None:
|
227 |
+
mixed_key_layer = self.key(encoder_hidden_states)
|
228 |
+
mixed_value_layer = self.value(encoder_hidden_states)
|
229 |
+
attention_mask = encoder_attention_mask
|
230 |
+
else:
|
231 |
+
mixed_key_layer = self.key(hidden_states)
|
232 |
+
mixed_value_layer = self.value(hidden_states)
|
233 |
+
|
234 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
235 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
236 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
237 |
+
|
238 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
239 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
240 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
241 |
+
if attention_mask is not None:
|
242 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
243 |
+
attention_scores = attention_scores + attention_mask
|
244 |
+
|
245 |
+
# Normalize the attention scores to probabilities.
|
246 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
247 |
+
|
248 |
+
# This is actually dropping out entire tokens to attend to, which might
|
249 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
250 |
+
attention_probs = self.dropout(attention_probs)
|
251 |
+
|
252 |
+
# Mask heads if we want to
|
253 |
+
if head_mask is not None:
|
254 |
+
attention_probs = attention_probs * head_mask
|
255 |
+
|
256 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
257 |
+
|
258 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
259 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
260 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
261 |
+
|
262 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
263 |
+
return outputs
|
264 |
+
|
265 |
+
|
266 |
+
class BertSelfOutput(nn.Module):
|
267 |
+
def __init__(self, config):
|
268 |
+
super().__init__()
|
269 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
270 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
271 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
272 |
+
|
273 |
+
def forward(self, hidden_states, input_tensor):
|
274 |
+
hidden_states = self.dense(hidden_states)
|
275 |
+
hidden_states = self.dropout(hidden_states)
|
276 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
277 |
+
return hidden_states
|
278 |
+
|
279 |
+
|
280 |
+
class BertAttention(nn.Module):
|
281 |
+
def __init__(self, config):
|
282 |
+
super().__init__()
|
283 |
+
self.self = BertSelfAttention(config)
|
284 |
+
self.output = BertSelfOutput(config)
|
285 |
+
self.pruned_heads = set()
|
286 |
+
|
287 |
+
def prune_heads(self, heads):
|
288 |
+
if len(heads) == 0:
|
289 |
+
return
|
290 |
+
heads, index = find_pruneable_heads_and_indices(
|
291 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
292 |
+
)
|
293 |
+
|
294 |
+
# Prune linear layers
|
295 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
296 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
297 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
298 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
299 |
+
|
300 |
+
# Update hyper params and store pruned heads
|
301 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
302 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
303 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
304 |
+
|
305 |
+
def forward(
|
306 |
+
self,
|
307 |
+
hidden_states,
|
308 |
+
attention_mask=None,
|
309 |
+
head_mask=None,
|
310 |
+
encoder_hidden_states=None,
|
311 |
+
encoder_attention_mask=None,
|
312 |
+
output_attentions=False,
|
313 |
+
):
|
314 |
+
self_outputs = self.self(
|
315 |
+
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
|
316 |
+
)
|
317 |
+
|
318 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
319 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
320 |
+
return outputs
|
321 |
+
|
322 |
+
|
323 |
+
class BertIntermediate(nn.Module):
|
324 |
+
def __init__(self, config):
|
325 |
+
super().__init__()
|
326 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
327 |
+
if isinstance(config.hidden_act, str):
|
328 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
329 |
+
else:
|
330 |
+
self.intermediate_act_fn = config.hidden_act
|
331 |
+
|
332 |
+
def forward(self, hidden_states):
|
333 |
+
hidden_states = self.dense(hidden_states)
|
334 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
335 |
+
return hidden_states
|
336 |
+
|
337 |
+
|
338 |
+
class BertOutput(nn.Module):
|
339 |
+
def __init__(self, config):
|
340 |
+
super().__init__()
|
341 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
342 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
343 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
344 |
+
|
345 |
+
def forward(self, hidden_states, input_tensor):
|
346 |
+
hidden_states = self.dense(hidden_states)
|
347 |
+
hidden_states = self.dropout(hidden_states)
|
348 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
349 |
+
return hidden_states
|
350 |
+
|
351 |
+
|
352 |
+
class BertLayer(nn.Module):
|
353 |
+
def __init__(self, config):
|
354 |
+
super().__init__()
|
355 |
+
self.attention = BertAttention(config)
|
356 |
+
self.is_decoder = config.is_decoder
|
357 |
+
if self.is_decoder:
|
358 |
+
self.crossattention = BertAttention(config)
|
359 |
+
self.intermediate = BertIntermediate(config)
|
360 |
+
self.output = BertOutput(config)
|
361 |
+
|
362 |
+
def forward(
|
363 |
+
self,
|
364 |
+
hidden_states,
|
365 |
+
attention_mask=None,
|
366 |
+
head_mask=None,
|
367 |
+
encoder_hidden_states=None,
|
368 |
+
encoder_attention_mask=None,
|
369 |
+
output_attentions=False,
|
370 |
+
):
|
371 |
+
self_attention_outputs = self.attention(
|
372 |
+
hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
|
373 |
+
)
|
374 |
+
attention_output = self_attention_outputs[0]
|
375 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
376 |
+
|
377 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
378 |
+
cross_attention_outputs = self.crossattention(
|
379 |
+
attention_output,
|
380 |
+
attention_mask,
|
381 |
+
head_mask,
|
382 |
+
encoder_hidden_states,
|
383 |
+
encoder_attention_mask,
|
384 |
+
output_attentions,
|
385 |
+
)
|
386 |
+
attention_output = cross_attention_outputs[0]
|
387 |
+
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
388 |
+
|
389 |
+
intermediate_output = self.intermediate(attention_output)
|
390 |
+
layer_output = self.output(intermediate_output, attention_output)
|
391 |
+
outputs = (layer_output,) + outputs
|
392 |
+
return outputs
|
393 |
+
|
394 |
+
|
395 |
+
class BertEncoder(nn.Module):
|
396 |
+
def __init__(self, config):
|
397 |
+
super().__init__()
|
398 |
+
self.config = config
|
399 |
+
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
400 |
+
|
401 |
+
def forward(
|
402 |
+
self,
|
403 |
+
hidden_states,
|
404 |
+
attention_mask=None,
|
405 |
+
head_mask=None,
|
406 |
+
encoder_hidden_states=None,
|
407 |
+
encoder_attention_mask=None,
|
408 |
+
output_attentions=False,
|
409 |
+
output_hidden_states=False,
|
410 |
+
):
|
411 |
+
all_hidden_states = ()
|
412 |
+
all_attentions = ()
|
413 |
+
for i, layer_module in enumerate(self.layer):
|
414 |
+
if output_hidden_states:
|
415 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
416 |
+
|
417 |
+
if getattr(self.config, "gradient_checkpointing", False):
|
418 |
+
|
419 |
+
def create_custom_forward(module):
|
420 |
+
def custom_forward(*inputs):
|
421 |
+
return module(*inputs, output_attentions)
|
422 |
+
|
423 |
+
return custom_forward
|
424 |
+
|
425 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
426 |
+
create_custom_forward(layer_module),
|
427 |
+
hidden_states,
|
428 |
+
attention_mask,
|
429 |
+
head_mask[i],
|
430 |
+
encoder_hidden_states,
|
431 |
+
encoder_attention_mask,
|
432 |
+
)
|
433 |
+
else:
|
434 |
+
layer_outputs = layer_module(
|
435 |
+
hidden_states,
|
436 |
+
attention_mask,
|
437 |
+
head_mask[i],
|
438 |
+
encoder_hidden_states,
|
439 |
+
encoder_attention_mask,
|
440 |
+
output_attentions,
|
441 |
+
)
|
442 |
+
hidden_states = layer_outputs[0]
|
443 |
+
|
444 |
+
if output_attentions:
|
445 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
446 |
+
|
447 |
+
# Add last layer
|
448 |
+
if output_hidden_states:
|
449 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
450 |
+
|
451 |
+
outputs = (hidden_states,)
|
452 |
+
if output_hidden_states:
|
453 |
+
outputs = outputs + (all_hidden_states,)
|
454 |
+
if output_attentions:
|
455 |
+
outputs = outputs + (all_attentions,)
|
456 |
+
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
457 |
+
|
458 |
+
|
459 |
+
class BertPooler(nn.Module):
|
460 |
+
def __init__(self, config):
|
461 |
+
super().__init__()
|
462 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
463 |
+
self.activation = nn.Tanh()
|
464 |
+
|
465 |
+
def forward(self, hidden_states):
|
466 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
467 |
+
# to the first token.
|
468 |
+
first_token_tensor = hidden_states[:, 0]
|
469 |
+
pooled_output = self.dense(first_token_tensor)
|
470 |
+
pooled_output = self.activation(pooled_output)
|
471 |
+
return pooled_output
|
472 |
+
|
473 |
+
|
474 |
+
class BertPredictionHeadTransform(nn.Module):
|
475 |
+
def __init__(self, config):
|
476 |
+
super().__init__()
|
477 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
478 |
+
if isinstance(config.hidden_act, str):
|
479 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
480 |
+
else:
|
481 |
+
self.transform_act_fn = config.hidden_act
|
482 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
483 |
+
|
484 |
+
def forward(self, hidden_states):
|
485 |
+
hidden_states = self.dense(hidden_states)
|
486 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
487 |
+
hidden_states = self.LayerNorm(hidden_states)
|
488 |
+
return hidden_states
|
489 |
+
|
490 |
+
|
491 |
+
class BertLMPredictionHead(nn.Module):
|
492 |
+
def __init__(self, config):
|
493 |
+
super().__init__()
|
494 |
+
self.transform = BertPredictionHeadTransform(config)
|
495 |
+
|
496 |
+
# The output weights are the same as the input embeddings, but there is
|
497 |
+
# an output-only bias for each token.
|
498 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
499 |
+
|
500 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
501 |
+
|
502 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
503 |
+
self.decoder.bias = self.bias
|
504 |
+
|
505 |
+
def forward(self, hidden_states):
|
506 |
+
hidden_states = self.transform(hidden_states)
|
507 |
+
hidden_states = self.decoder(hidden_states)
|
508 |
+
return hidden_states
|
509 |
+
|
510 |
+
|
511 |
+
class BertOnlyMLMHead(nn.Module):
|
512 |
+
def __init__(self, config):
|
513 |
+
super().__init__()
|
514 |
+
self.predictions = BertLMPredictionHead(config)
|
515 |
+
|
516 |
+
def forward(self, sequence_output):
|
517 |
+
prediction_scores = self.predictions(sequence_output)
|
518 |
+
return prediction_scores
|
519 |
+
|
520 |
+
|
521 |
+
class BertOnlyNSPHead(nn.Module):
|
522 |
+
def __init__(self, config):
|
523 |
+
super().__init__()
|
524 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
525 |
+
|
526 |
+
def forward(self, pooled_output):
|
527 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
528 |
+
return seq_relationship_score
|
529 |
+
|
530 |
+
|
531 |
+
class BertPreTrainingHeads(nn.Module):
|
532 |
+
def __init__(self, config):
|
533 |
+
super().__init__()
|
534 |
+
self.predictions = BertLMPredictionHead(config)
|
535 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
536 |
+
|
537 |
+
def forward(self, sequence_output, pooled_output):
|
538 |
+
prediction_scores = self.predictions(sequence_output)
|
539 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
540 |
+
return prediction_scores, seq_relationship_score
|
541 |
+
|
542 |
+
|
543 |
+
class BertPreTrainedModel(PreTrainedModel):
|
544 |
+
""" An abstract class to handle weights initialization and
|
545 |
+
a simple interface for downloading and loading pretrained models.
|
546 |
+
"""
|
547 |
+
|
548 |
+
config_class = BertConfig
|
549 |
+
load_tf_weights = load_tf_weights_in_bert
|
550 |
+
base_model_prefix = "bert"
|
551 |
+
|
552 |
+
def _init_weights(self, module):
|
553 |
+
""" Initialize the weights """
|
554 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
555 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
556 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
557 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
558 |
+
elif isinstance(module, BertLayerNorm):
|
559 |
+
module.bias.data.zero_()
|
560 |
+
module.weight.data.fill_(1.0)
|
561 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
562 |
+
module.bias.data.zero_()
|
563 |
+
|
564 |
+
|
565 |
+
BERT_START_DOCSTRING = r"""
|
566 |
+
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
567 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
568 |
+
usage and behavior.
|
569 |
+
|
570 |
+
Parameters:
|
571 |
+
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
572 |
+
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
573 |
+
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
574 |
+
"""
|
575 |
+
|
576 |
+
BERT_INPUTS_DOCSTRING = r"""
|
577 |
+
Args:
|
578 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
|
579 |
+
Indices of input sequence tokens in the vocabulary.
|
580 |
+
|
581 |
+
Indices can be obtained using :class:`transformers.BertTokenizer`.
|
582 |
+
See :func:`transformers.PreTrainedTokenizer.encode` and
|
583 |
+
:func:`transformers.PreTrainedTokenizer.__call__` for details.
|
584 |
+
|
585 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
586 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
587 |
+
Mask to avoid performing attention on padding token indices.
|
588 |
+
Mask values selected in ``[0, 1]``:
|
589 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
590 |
+
|
591 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
592 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
593 |
+
Segment token indices to indicate first and second portions of the inputs.
|
594 |
+
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
595 |
+
corresponds to a `sentence B` token
|
596 |
+
|
597 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
598 |
+
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
599 |
+
Indices of positions of each input sequence tokens in the position embeddings.
|
600 |
+
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
601 |
+
|
602 |
+
`What are position IDs? <../glossary.html#position-ids>`_
|
603 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
604 |
+
Mask to nullify selected heads of the self-attention modules.
|
605 |
+
Mask values selected in ``[0, 1]``:
|
606 |
+
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
607 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
608 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
609 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
610 |
+
than the model's internal embedding lookup matrix.
|
611 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
612 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
613 |
+
if the model is configured as a decoder.
|
614 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
615 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
616 |
+
is used in the cross-attention if the model is configured as a decoder.
|
617 |
+
Mask values selected in ``[0, 1]``:
|
618 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
619 |
+
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
620 |
+
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
621 |
+
"""
|
622 |
+
|
623 |
+
|
624 |
+
|
625 |
+
[DOCS]
|
626 |
+
@add_start_docstrings(
|
627 |
+
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
628 |
+
BERT_START_DOCSTRING,
|
629 |
+
)
|
630 |
+
class BertModel(BertPreTrainedModel):
|
631 |
+
"""
|
632 |
+
|
633 |
+
The model can behave as an encoder (with only self-attention) as well
|
634 |
+
as a decoder, in which case a layer of cross-attention is added between
|
635 |
+
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
|
636 |
+
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
637 |
+
|
638 |
+
To behave as an decoder the model needs to be initialized with the
|
639 |
+
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
|
640 |
+
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
|
641 |
+
|
642 |
+
.. _`Attention is all you need`:
|
643 |
+
https://arxiv.org/abs/1706.03762
|
644 |
+
|
645 |
+
"""
|
646 |
+
|
647 |
+
def __init__(self, config):
|
648 |
+
super().__init__(config)
|
649 |
+
self.config = config
|
650 |
+
|
651 |
+
self.embeddings = BertEmbeddings(config)
|
652 |
+
self.encoder = BertEncoder(config)
|
653 |
+
self.pooler = BertPooler(config)
|
654 |
+
|
655 |
+
self.init_weights()
|
656 |
+
|
657 |
+
|
658 |
+
[DOCS]
|
659 |
+
def get_input_embeddings(self):
|
660 |
+
return self.embeddings.word_embeddings
|
661 |
+
|
662 |
+
|
663 |
+
|
664 |
+
[DOCS]
|
665 |
+
def set_input_embeddings(self, value):
|
666 |
+
self.embeddings.word_embeddings = value
|
667 |
+
|
668 |
+
|
669 |
+
def _prune_heads(self, heads_to_prune):
|
670 |
+
""" Prunes heads of the model.
|
671 |
+
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
672 |
+
See base class PreTrainedModel
|
673 |
+
"""
|
674 |
+
for layer, heads in heads_to_prune.items():
|
675 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
676 |
+
|
677 |
+
|
678 |
+
[DOCS]
|
679 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
680 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
681 |
+
def forward(
|
682 |
+
self,
|
683 |
+
input_ids=None,
|
684 |
+
attention_mask=None,
|
685 |
+
token_type_ids=None,
|
686 |
+
position_ids=None,
|
687 |
+
head_mask=None,
|
688 |
+
inputs_embeds=None,
|
689 |
+
encoder_hidden_states=None,
|
690 |
+
encoder_attention_mask=None,
|
691 |
+
output_attentions=None,
|
692 |
+
output_hidden_states=None,
|
693 |
+
):
|
694 |
+
r"""
|
695 |
+
Return:
|
696 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
697 |
+
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
698 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
699 |
+
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
700 |
+
Last layer hidden-state of the first token of the sequence (classification token)
|
701 |
+
further processed by a Linear layer and a Tanh activation function. The Linear
|
702 |
+
layer weights are trained from the next sentence prediction (classification)
|
703 |
+
objective during pre-training.
|
704 |
+
|
705 |
+
This output is usually *not* a good summary
|
706 |
+
of the semantic content of the input, you're often better with averaging or pooling
|
707 |
+
the sequence of hidden-states for the whole input sequence.
|
708 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
709 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
710 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
711 |
+
|
712 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
713 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
714 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
715 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
716 |
+
|
717 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
718 |
+
heads.
|
719 |
+
"""
|
720 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
721 |
+
output_hidden_states = (
|
722 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
723 |
+
)
|
724 |
+
|
725 |
+
if input_ids is not None and inputs_embeds is not None:
|
726 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
727 |
+
elif input_ids is not None:
|
728 |
+
input_shape = input_ids.size()
|
729 |
+
elif inputs_embeds is not None:
|
730 |
+
input_shape = inputs_embeds.size()[:-1]
|
731 |
+
else:
|
732 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
733 |
+
|
734 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
735 |
+
|
736 |
+
if attention_mask is None:
|
737 |
+
attention_mask = torch.ones(input_shape, device=device)
|
738 |
+
if token_type_ids is None:
|
739 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
740 |
+
|
741 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
742 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
743 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
744 |
+
|
745 |
+
# If a 2D ou 3D attention mask is provided for the cross-attention
|
746 |
+
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
747 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
748 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
749 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
750 |
+
if encoder_attention_mask is None:
|
751 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
752 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
753 |
+
else:
|
754 |
+
encoder_extended_attention_mask = None
|
755 |
+
|
756 |
+
# Prepare head mask if needed
|
757 |
+
# 1.0 in head_mask indicate we keep the head
|
758 |
+
# attention_probs has shape bsz x n_heads x N x N
|
759 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
760 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
761 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
762 |
+
|
763 |
+
embedding_output = self.embeddings(
|
764 |
+
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
765 |
+
)
|
766 |
+
encoder_outputs = self.encoder(
|
767 |
+
embedding_output,
|
768 |
+
attention_mask=extended_attention_mask,
|
769 |
+
head_mask=head_mask,
|
770 |
+
encoder_hidden_states=encoder_hidden_states,
|
771 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
772 |
+
output_attentions=output_attentions,
|
773 |
+
output_hidden_states=output_hidden_states,
|
774 |
+
)
|
775 |
+
sequence_output = encoder_outputs[0]
|
776 |
+
pooled_output = self.pooler(sequence_output)
|
777 |
+
|
778 |
+
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
779 |
+
1:
|
780 |
+
] # add hidden_states and attentions if they are here
|
781 |
+
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
782 |
+
|
783 |
+
|
784 |
+
|
785 |
+
|
786 |
+
[DOCS]
|
787 |
+
@add_start_docstrings(
|
788 |
+
"""Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
|
789 |
+
a `next sentence prediction (classification)` head. """,
|
790 |
+
BERT_START_DOCSTRING,
|
791 |
+
)
|
792 |
+
class BertForPreTraining(BertPreTrainedModel):
|
793 |
+
def __init__(self, config):
|
794 |
+
super().__init__(config)
|
795 |
+
|
796 |
+
self.bert = BertModel(config)
|
797 |
+
self.cls = BertPreTrainingHeads(config)
|
798 |
+
|
799 |
+
self.init_weights()
|
800 |
+
|
801 |
+
|
802 |
+
[DOCS]
|
803 |
+
def get_output_embeddings(self):
|
804 |
+
return self.cls.predictions.decoder
|
805 |
+
|
806 |
+
|
807 |
+
|
808 |
+
[DOCS]
|
809 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
810 |
+
def forward(
|
811 |
+
self,
|
812 |
+
input_ids=None,
|
813 |
+
attention_mask=None,
|
814 |
+
token_type_ids=None,
|
815 |
+
position_ids=None,
|
816 |
+
head_mask=None,
|
817 |
+
inputs_embeds=None,
|
818 |
+
labels=None,
|
819 |
+
next_sentence_label=None,
|
820 |
+
output_attentions=None,
|
821 |
+
output_hidden_states=None,
|
822 |
+
**kwargs
|
823 |
+
):
|
824 |
+
r"""
|
825 |
+
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
|
826 |
+
Labels for computing the masked language modeling loss.
|
827 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
828 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
829 |
+
in ``[0, ..., config.vocab_size]``
|
830 |
+
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
|
831 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
|
832 |
+
Indices should be in ``[0, 1]``.
|
833 |
+
``0`` indicates sequence B is a continuation of sequence A,
|
834 |
+
``1`` indicates sequence B is a random sequence.
|
835 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
836 |
+
Used to hide legacy arguments that have been deprecated.
|
837 |
+
|
838 |
+
Returns:
|
839 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
840 |
+
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
841 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
|
842 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
843 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
844 |
+
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
845 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False
|
846 |
+
continuation before SoftMax).
|
847 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
848 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
849 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
850 |
+
|
851 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
852 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
853 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
854 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
855 |
+
|
856 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
857 |
+
heads.
|
858 |
+
|
859 |
+
|
860 |
+
Examples::
|
861 |
+
|
862 |
+
>>> from transformers import BertTokenizer, BertForPreTraining
|
863 |
+
>>> import torch
|
864 |
+
|
865 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
866 |
+
>>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
|
867 |
+
|
868 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
869 |
+
>>> outputs = model(**inputs)
|
870 |
+
|
871 |
+
>>> prediction_scores, seq_relationship_scores = outputs[:2]
|
872 |
+
|
873 |
+
"""
|
874 |
+
if "masked_lm_labels" in kwargs:
|
875 |
+
warnings.warn(
|
876 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
877 |
+
DeprecationWarning,
|
878 |
+
)
|
879 |
+
labels = kwargs.pop("masked_lm_labels")
|
880 |
+
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
881 |
+
|
882 |
+
outputs = self.bert(
|
883 |
+
input_ids,
|
884 |
+
attention_mask=attention_mask,
|
885 |
+
token_type_ids=token_type_ids,
|
886 |
+
position_ids=position_ids,
|
887 |
+
head_mask=head_mask,
|
888 |
+
inputs_embeds=inputs_embeds,
|
889 |
+
output_attentions=output_attentions,
|
890 |
+
output_hidden_states=output_hidden_states,
|
891 |
+
)
|
892 |
+
|
893 |
+
sequence_output, pooled_output = outputs[:2]
|
894 |
+
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
895 |
+
|
896 |
+
outputs = (prediction_scores, seq_relationship_score,) + outputs[
|
897 |
+
2:
|
898 |
+
] # add hidden states and attention if they are here
|
899 |
+
|
900 |
+
if labels is not None and next_sentence_label is not None:
|
901 |
+
loss_fct = CrossEntropyLoss()
|
902 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
903 |
+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
904 |
+
total_loss = masked_lm_loss + next_sentence_loss
|
905 |
+
outputs = (total_loss,) + outputs
|
906 |
+
|
907 |
+
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
908 |
+
|
909 |
+
|
910 |
+
|
911 |
+
@add_start_docstrings(
|
912 |
+
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
|
913 |
+
)
|
914 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
915 |
+
def __init__(self, config):
|
916 |
+
super().__init__(config)
|
917 |
+
assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
|
918 |
+
|
919 |
+
self.bert = BertModel(config)
|
920 |
+
self.cls = BertOnlyMLMHead(config)
|
921 |
+
|
922 |
+
self.init_weights()
|
923 |
+
|
924 |
+
def get_output_embeddings(self):
|
925 |
+
return self.cls.predictions.decoder
|
926 |
+
|
927 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
928 |
+
def forward(
|
929 |
+
self,
|
930 |
+
input_ids=None,
|
931 |
+
attention_mask=None,
|
932 |
+
token_type_ids=None,
|
933 |
+
position_ids=None,
|
934 |
+
head_mask=None,
|
935 |
+
inputs_embeds=None,
|
936 |
+
labels=None,
|
937 |
+
encoder_hidden_states=None,
|
938 |
+
encoder_attention_mask=None,
|
939 |
+
output_attentions=None,
|
940 |
+
output_hidden_states=None,
|
941 |
+
**kwargs
|
942 |
+
):
|
943 |
+
r"""
|
944 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
945 |
+
Labels for computing the left-to-right language modeling loss (next word prediction).
|
946 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
947 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
948 |
+
in ``[0, ..., config.vocab_size]``
|
949 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
950 |
+
Used to hide legacy arguments that have been deprecated.
|
951 |
+
|
952 |
+
Returns:
|
953 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
954 |
+
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
955 |
+
Next token prediction loss.
|
956 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
957 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
958 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
959 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
960 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
961 |
+
|
962 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
963 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
964 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
965 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
966 |
+
|
967 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
968 |
+
heads.
|
969 |
+
|
970 |
+
Example::
|
971 |
+
|
972 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
973 |
+
>>> import torch
|
974 |
+
|
975 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
976 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
977 |
+
>>> config.is_decoder = True
|
978 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
979 |
+
|
980 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
981 |
+
>>> outputs = model(**inputs)
|
982 |
+
|
983 |
+
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
984 |
+
"""
|
985 |
+
|
986 |
+
outputs = self.bert(
|
987 |
+
input_ids,
|
988 |
+
attention_mask=attention_mask,
|
989 |
+
token_type_ids=token_type_ids,
|
990 |
+
position_ids=position_ids,
|
991 |
+
head_mask=head_mask,
|
992 |
+
inputs_embeds=inputs_embeds,
|
993 |
+
encoder_hidden_states=encoder_hidden_states,
|
994 |
+
encoder_attention_mask=encoder_attention_mask,
|
995 |
+
output_attentions=output_attentions,
|
996 |
+
output_hidden_states=output_hidden_states,
|
997 |
+
)
|
998 |
+
|
999 |
+
sequence_output = outputs[0]
|
1000 |
+
prediction_scores = self.cls(sequence_output)
|
1001 |
+
|
1002 |
+
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
1003 |
+
|
1004 |
+
if labels is not None:
|
1005 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1006 |
+
prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
1007 |
+
labels = labels[:, 1:].contiguous()
|
1008 |
+
loss_fct = CrossEntropyLoss()
|
1009 |
+
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
1010 |
+
outputs = (ltr_lm_loss,) + outputs
|
1011 |
+
|
1012 |
+
return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
|
1013 |
+
|
1014 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
1015 |
+
input_shape = input_ids.shape
|
1016 |
+
|
1017 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1018 |
+
if attention_mask is None:
|
1019 |
+
attention_mask = input_ids.new_ones(input_shape)
|
1020 |
+
|
1021 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
1022 |
+
|
1023 |
+
|
1024 |
+
|
1025 |
+
[DOCS]
|
1026 |
+
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
1027 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
1028 |
+
def __init__(self, config):
|
1029 |
+
super().__init__(config)
|
1030 |
+
assert (
|
1031 |
+
not config.is_decoder
|
1032 |
+
), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
|
1033 |
+
|
1034 |
+
self.bert = BertModel(config)
|
1035 |
+
self.cls = BertOnlyMLMHead(config)
|
1036 |
+
|
1037 |
+
self.init_weights()
|
1038 |
+
|
1039 |
+
|
1040 |
+
[DOCS]
|
1041 |
+
def get_output_embeddings(self):
|
1042 |
+
return self.cls.predictions.decoder
|
1043 |
+
|
1044 |
+
|
1045 |
+
|
1046 |
+
[DOCS]
|
1047 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
1048 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
1049 |
+
def forward(
|
1050 |
+
self,
|
1051 |
+
input_ids=None,
|
1052 |
+
attention_mask=None,
|
1053 |
+
token_type_ids=None,
|
1054 |
+
position_ids=None,
|
1055 |
+
head_mask=None,
|
1056 |
+
inputs_embeds=None,
|
1057 |
+
labels=None,
|
1058 |
+
encoder_hidden_states=None,
|
1059 |
+
encoder_attention_mask=None,
|
1060 |
+
output_attentions=None,
|
1061 |
+
output_hidden_states=None,
|
1062 |
+
**kwargs
|
1063 |
+
):
|
1064 |
+
r"""
|
1065 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
1066 |
+
Labels for computing the masked language modeling loss.
|
1067 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
1068 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
1069 |
+
in ``[0, ..., config.vocab_size]``
|
1070 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
1071 |
+
Used to hide legacy arguments that have been deprecated.
|
1072 |
+
|
1073 |
+
Returns:
|
1074 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1075 |
+
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
1076 |
+
Masked language modeling loss.
|
1077 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
1078 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
1079 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
1080 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1081 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1082 |
+
|
1083 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1084 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
1085 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1086 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1087 |
+
|
1088 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
1089 |
+
heads.
|
1090 |
+
"""
|
1091 |
+
if "masked_lm_labels" in kwargs:
|
1092 |
+
warnings.warn(
|
1093 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
1094 |
+
DeprecationWarning,
|
1095 |
+
)
|
1096 |
+
labels = kwargs.pop("masked_lm_labels")
|
1097 |
+
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
|
1098 |
+
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
1099 |
+
|
1100 |
+
outputs = self.bert(
|
1101 |
+
input_ids,
|
1102 |
+
attention_mask=attention_mask,
|
1103 |
+
token_type_ids=token_type_ids,
|
1104 |
+
position_ids=position_ids,
|
1105 |
+
head_mask=head_mask,
|
1106 |
+
inputs_embeds=inputs_embeds,
|
1107 |
+
encoder_hidden_states=encoder_hidden_states,
|
1108 |
+
encoder_attention_mask=encoder_attention_mask,
|
1109 |
+
output_attentions=output_attentions,
|
1110 |
+
output_hidden_states=output_hidden_states,
|
1111 |
+
)
|
1112 |
+
|
1113 |
+
sequence_output = outputs[0]
|
1114 |
+
prediction_scores = self.cls(sequence_output)
|
1115 |
+
|
1116 |
+
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
1117 |
+
|
1118 |
+
if labels is not None:
|
1119 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
1120 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
1121 |
+
outputs = (masked_lm_loss,) + outputs
|
1122 |
+
|
1123 |
+
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
1124 |
+
|
1125 |
+
|
1126 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
1127 |
+
input_shape = input_ids.shape
|
1128 |
+
effective_batch_size = input_shape[0]
|
1129 |
+
|
1130 |
+
# add a dummy token
|
1131 |
+
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
|
1132 |
+
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
|
1133 |
+
dummy_token = torch.full(
|
1134 |
+
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
|
1135 |
+
)
|
1136 |
+
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
1137 |
+
|
1138 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
1139 |
+
|
1140 |
+
|
1141 |
+
|
1142 |
+
|
1143 |
+
[DOCS]
|
1144 |
+
@add_start_docstrings(
|
1145 |
+
"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
|
1146 |
+
)
|
1147 |
+
class BertForNextSentencePrediction(BertPreTrainedModel):
|
1148 |
+
def __init__(self, config):
|
1149 |
+
super().__init__(config)
|
1150 |
+
|
1151 |
+
self.bert = BertModel(config)
|
1152 |
+
self.cls = BertOnlyNSPHead(config)
|
1153 |
+
|
1154 |
+
self.init_weights()
|
1155 |
+
|
1156 |
+
|
1157 |
+
[DOCS]
|
1158 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
1159 |
+
def forward(
|
1160 |
+
self,
|
1161 |
+
input_ids=None,
|
1162 |
+
attention_mask=None,
|
1163 |
+
token_type_ids=None,
|
1164 |
+
position_ids=None,
|
1165 |
+
head_mask=None,
|
1166 |
+
inputs_embeds=None,
|
1167 |
+
next_sentence_label=None,
|
1168 |
+
output_attentions=None,
|
1169 |
+
output_hidden_states=None,
|
1170 |
+
):
|
1171 |
+
r"""
|
1172 |
+
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1173 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
|
1174 |
+
Indices should be in ``[0, 1]``.
|
1175 |
+
``0`` indicates sequence B is a continuation of sequence A,
|
1176 |
+
``1`` indicates sequence B is a random sequence.
|
1177 |
+
|
1178 |
+
Returns:
|
1179 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1180 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
|
1181 |
+
Next sequence prediction (classification) loss.
|
1182 |
+
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
1183 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
|
1184 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
1185 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1186 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1187 |
+
|
1188 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1189 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
1190 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1191 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1192 |
+
|
1193 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
1194 |
+
heads.
|
1195 |
+
|
1196 |
+
Examples::
|
1197 |
+
|
1198 |
+
>>> from transformers import BertTokenizer, BertForNextSentencePrediction
|
1199 |
+
>>> import torch
|
1200 |
+
|
1201 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1202 |
+
>>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
1203 |
+
|
1204 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
1205 |
+
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
1206 |
+
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
|
1207 |
+
|
1208 |
+
>>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
|
1209 |
+
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
1210 |
+
"""
|
1211 |
+
|
1212 |
+
outputs = self.bert(
|
1213 |
+
input_ids,
|
1214 |
+
attention_mask=attention_mask,
|
1215 |
+
token_type_ids=token_type_ids,
|
1216 |
+
position_ids=position_ids,
|
1217 |
+
head_mask=head_mask,
|
1218 |
+
inputs_embeds=inputs_embeds,
|
1219 |
+
output_attentions=output_attentions,
|
1220 |
+
output_hidden_states=output_hidden_states,
|
1221 |
+
)
|
1222 |
+
|
1223 |
+
pooled_output = outputs[1]
|
1224 |
+
|
1225 |
+
seq_relationship_score = self.cls(pooled_output)
|
1226 |
+
|
1227 |
+
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
|
1228 |
+
if next_sentence_label is not None:
|
1229 |
+
loss_fct = CrossEntropyLoss()
|
1230 |
+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
1231 |
+
outputs = (next_sentence_loss,) + outputs
|
1232 |
+
|
1233 |
+
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
|
1234 |
+
|
1235 |
+
|
1236 |
+
|
1237 |
+
|
1238 |
+
[DOCS]
|
1239 |
+
@add_start_docstrings(
|
1240 |
+
"""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
1241 |
+
the pooled output) e.g. for GLUE tasks. """,
|
1242 |
+
BERT_START_DOCSTRING,
|
1243 |
+
)
|
1244 |
+
class BertForSequenceClassification(BertPreTrainedModel):
|
1245 |
+
def __init__(self, config):
|
1246 |
+
super().__init__(config)
|
1247 |
+
self.num_labels = config.num_labels
|
1248 |
+
|
1249 |
+
self.bert = BertModel(config)
|
1250 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1251 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1252 |
+
|
1253 |
+
self.init_weights()
|
1254 |
+
|
1255 |
+
|
1256 |
+
[DOCS]
|
1257 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
1258 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
1259 |
+
def forward(
|
1260 |
+
self,
|
1261 |
+
input_ids=None,
|
1262 |
+
attention_mask=None,
|
1263 |
+
token_type_ids=None,
|
1264 |
+
position_ids=None,
|
1265 |
+
head_mask=None,
|
1266 |
+
inputs_embeds=None,
|
1267 |
+
labels=None,
|
1268 |
+
output_attentions=None,
|
1269 |
+
output_hidden_states=None,
|
1270 |
+
):
|
1271 |
+
r"""
|
1272 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1273 |
+
Labels for computing the sequence classification/regression loss.
|
1274 |
+
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
1275 |
+
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
1276 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1277 |
+
|
1278 |
+
Returns:
|
1279 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1280 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
1281 |
+
Classification (or regression if config.num_labels==1) loss.
|
1282 |
+
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
1283 |
+
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
1284 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
1285 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1286 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1287 |
+
|
1288 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1289 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
1290 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1291 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1292 |
+
|
1293 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
1294 |
+
heads.
|
1295 |
+
"""
|
1296 |
+
|
1297 |
+
outputs = self.bert(
|
1298 |
+
input_ids,
|
1299 |
+
attention_mask=attention_mask,
|
1300 |
+
token_type_ids=token_type_ids,
|
1301 |
+
position_ids=position_ids,
|
1302 |
+
head_mask=head_mask,
|
1303 |
+
inputs_embeds=inputs_embeds,
|
1304 |
+
output_attentions=output_attentions,
|
1305 |
+
output_hidden_states=output_hidden_states,
|
1306 |
+
)
|
1307 |
+
|
1308 |
+
pooled_output = outputs[1]
|
1309 |
+
|
1310 |
+
pooled_output = self.dropout(pooled_output)
|
1311 |
+
logits = self.classifier(pooled_output)
|
1312 |
+
|
1313 |
+
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
1314 |
+
|
1315 |
+
if labels is not None:
|
1316 |
+
if self.num_labels == 1:
|
1317 |
+
# We are doing regression
|
1318 |
+
loss_fct = MSELoss()
|
1319 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
1320 |
+
else:
|
1321 |
+
loss_fct = CrossEntropyLoss()
|
1322 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1323 |
+
outputs = (loss,) + outputs
|
1324 |
+
|
1325 |
+
return outputs # (loss), logits, (hidden_states), (attentions)
|
1326 |
+
|
1327 |
+
|
1328 |
+
|
1329 |
+
|
1330 |
+
[DOCS]
|
1331 |
+
@add_start_docstrings(
|
1332 |
+
"""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
1333 |
+
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
1334 |
+
BERT_START_DOCSTRING,
|
1335 |
+
)
|
1336 |
+
class BertForMultipleChoice(BertPreTrainedModel):
|
1337 |
+
def __init__(self, config):
|
1338 |
+
super().__init__(config)
|
1339 |
+
|
1340 |
+
self.bert = BertModel(config)
|
1341 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1342 |
+
self.classifier = nn.Linear(config.hidden_size, 1)
|
1343 |
+
|
1344 |
+
self.init_weights()
|
1345 |
+
|
1346 |
+
|
1347 |
+
[DOCS]
|
1348 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
|
1349 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
1350 |
+
def forward(
|
1351 |
+
self,
|
1352 |
+
input_ids=None,
|
1353 |
+
attention_mask=None,
|
1354 |
+
token_type_ids=None,
|
1355 |
+
position_ids=None,
|
1356 |
+
head_mask=None,
|
1357 |
+
inputs_embeds=None,
|
1358 |
+
labels=None,
|
1359 |
+
output_attentions=None,
|
1360 |
+
output_hidden_states=None,
|
1361 |
+
):
|
1362 |
+
r"""
|
1363 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1364 |
+
Labels for computing the multiple choice classification loss.
|
1365 |
+
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
|
1366 |
+
of the input tensors. (see `input_ids` above)
|
1367 |
+
|
1368 |
+
Returns:
|
1369 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1370 |
+
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
|
1371 |
+
Classification loss.
|
1372 |
+
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
1373 |
+
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
1374 |
+
|
1375 |
+
Classification scores (before SoftMax).
|
1376 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
1377 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1378 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1379 |
+
|
1380 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1381 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
1382 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1383 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1384 |
+
|
1385 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
1386 |
+
heads.
|
1387 |
+
"""
|
1388 |
+
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
1389 |
+
|
1390 |
+
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
1391 |
+
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
1392 |
+
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
1393 |
+
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
1394 |
+
inputs_embeds = (
|
1395 |
+
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
1396 |
+
if inputs_embeds is not None
|
1397 |
+
else None
|
1398 |
+
)
|
1399 |
+
|
1400 |
+
outputs = self.bert(
|
1401 |
+
input_ids,
|
1402 |
+
attention_mask=attention_mask,
|
1403 |
+
token_type_ids=token_type_ids,
|
1404 |
+
position_ids=position_ids,
|
1405 |
+
head_mask=head_mask,
|
1406 |
+
inputs_embeds=inputs_embeds,
|
1407 |
+
output_attentions=output_attentions,
|
1408 |
+
output_hidden_states=output_hidden_states,
|
1409 |
+
)
|
1410 |
+
|
1411 |
+
pooled_output = outputs[1]
|
1412 |
+
|
1413 |
+
pooled_output = self.dropout(pooled_output)
|
1414 |
+
logits = self.classifier(pooled_output)
|
1415 |
+
reshaped_logits = logits.view(-1, num_choices)
|
1416 |
+
|
1417 |
+
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
1418 |
+
|
1419 |
+
if labels is not None:
|
1420 |
+
loss_fct = CrossEntropyLoss()
|
1421 |
+
loss = loss_fct(reshaped_logits, labels)
|
1422 |
+
outputs = (loss,) + outputs
|
1423 |
+
|
1424 |
+
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
1425 |
+
|
1426 |
+
|
1427 |
+
|
1428 |
+
|
1429 |
+
[DOCS]
|
1430 |
+
@add_start_docstrings(
|
1431 |
+
"""Bert Model with a token classification head on top (a linear layer on top of
|
1432 |
+
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
1433 |
+
BERT_START_DOCSTRING,
|
1434 |
+
)
|
1435 |
+
class BertForTokenClassification(BertPreTrainedModel):
|
1436 |
+
def __init__(self, config):
|
1437 |
+
super().__init__(config)
|
1438 |
+
self.num_labels = config.num_labels
|
1439 |
+
|
1440 |
+
self.bert = BertModel(config)
|
1441 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1442 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1443 |
+
|
1444 |
+
self.init_weights()
|
1445 |
+
|
1446 |
+
|
1447 |
+
[DOCS]
|
1448 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
1449 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
1450 |
+
def forward(
|
1451 |
+
self,
|
1452 |
+
input_ids=None,
|
1453 |
+
attention_mask=None,
|
1454 |
+
token_type_ids=None,
|
1455 |
+
position_ids=None,
|
1456 |
+
head_mask=None,
|
1457 |
+
inputs_embeds=None,
|
1458 |
+
labels=None,
|
1459 |
+
output_attentions=None,
|
1460 |
+
output_hidden_states=None,
|
1461 |
+
):
|
1462 |
+
r"""
|
1463 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
1464 |
+
Labels for computing the token classification loss.
|
1465 |
+
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
1466 |
+
|
1467 |
+
Returns:
|
1468 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1469 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
|
1470 |
+
Classification loss.
|
1471 |
+
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
|
1472 |
+
Classification scores (before SoftMax).
|
1473 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
1474 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1475 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1476 |
+
|
1477 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1478 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
1479 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1480 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1481 |
+
|
1482 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
1483 |
+
heads.
|
1484 |
+
"""
|
1485 |
+
|
1486 |
+
outputs = self.bert(
|
1487 |
+
input_ids,
|
1488 |
+
attention_mask=attention_mask,
|
1489 |
+
token_type_ids=token_type_ids,
|
1490 |
+
position_ids=position_ids,
|
1491 |
+
head_mask=head_mask,
|
1492 |
+
inputs_embeds=inputs_embeds,
|
1493 |
+
output_attentions=output_attentions,
|
1494 |
+
output_hidden_states=output_hidden_states,
|
1495 |
+
)
|
1496 |
+
|
1497 |
+
sequence_output = outputs[0]
|
1498 |
+
|
1499 |
+
sequence_output = self.dropout(sequence_output)
|
1500 |
+
logits = self.classifier(sequence_output)
|
1501 |
+
|
1502 |
+
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
1503 |
+
if labels is not None:
|
1504 |
+
loss_fct = CrossEntropyLoss()
|
1505 |
+
# Only keep active parts of the loss
|
1506 |
+
if attention_mask is not None:
|
1507 |
+
active_loss = attention_mask.view(-1) == 1
|
1508 |
+
active_logits = logits.view(-1, self.num_labels)
|
1509 |
+
active_labels = torch.where(
|
1510 |
+
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
1511 |
+
)
|
1512 |
+
loss = loss_fct(active_logits, active_labels)
|
1513 |
+
else:
|
1514 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1515 |
+
outputs = (loss,) + outputs
|
1516 |
+
|
1517 |
+
return outputs # (loss), scores, (hidden_states), (attentions)
|
1518 |
+
|
1519 |
+
|
1520 |
+
|
1521 |
+
|
1522 |
+
[DOCS]
|
1523 |
+
@add_start_docstrings(
|
1524 |
+
"""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
1525 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
|
1526 |
+
BERT_START_DOCSTRING,
|
1527 |
+
)
|
1528 |
+
class BertForQuestionAnswering(BertPreTrainedModel):
|
1529 |
+
def __init__(self, config):
|
1530 |
+
super().__init__(config)
|
1531 |
+
self.num_labels = config.num_labels
|
1532 |
+
|
1533 |
+
self.bert = BertModel(config)
|
1534 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
1535 |
+
|
1536 |
+
self.init_weights()
|
1537 |
+
|
1538 |
+
|
1539 |
+
[DOCS]
|
1540 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
1541 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
1542 |
+
def forward(
|
1543 |
+
self,
|
1544 |
+
input_ids=None,
|
1545 |
+
attention_mask=None,
|
1546 |
+
token_type_ids=None,
|
1547 |
+
position_ids=None,
|
1548 |
+
head_mask=None,
|
1549 |
+
inputs_embeds=None,
|
1550 |
+
start_positions=None,
|
1551 |
+
end_positions=None,
|
1552 |
+
output_attentions=None,
|
1553 |
+
output_hidden_states=None,
|
1554 |
+
):
|
1555 |
+
r"""
|
1556 |
+
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1557 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1558 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
1559 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
1560 |
+
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1561 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1562 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
1563 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
1564 |
+
|
1565 |
+
Returns:
|
1566 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1567 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
1568 |
+
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
1569 |
+
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
1570 |
+
Span-start scores (before SoftMax).
|
1571 |
+
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
1572 |
+
Span-end scores (before SoftMax).
|
1573 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
1574 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1575 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1576 |
+
|
1577 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1578 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
1579 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1580 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1581 |
+
|
1582 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
1583 |
+
heads.
|
1584 |
+
"""
|
1585 |
+
|
1586 |
+
outputs = self.bert(
|
1587 |
+
input_ids,
|
1588 |
+
attention_mask=attention_mask,
|
1589 |
+
token_type_ids=token_type_ids,
|
1590 |
+
position_ids=position_ids,
|
1591 |
+
head_mask=head_mask,
|
1592 |
+
inputs_embeds=inputs_embeds,
|
1593 |
+
output_attentions=output_attentions,
|
1594 |
+
output_hidden_states=output_hidden_states,
|
1595 |
+
)
|
1596 |
+
|
1597 |
+
sequence_output = outputs[0]
|
1598 |
+
|
1599 |
+
logits = self.qa_outputs(sequence_output)
|
1600 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1601 |
+
start_logits = start_logits.squeeze(-1)
|
1602 |
+
end_logits = end_logits.squeeze(-1)
|
1603 |
+
|
1604 |
+
outputs = (start_logits, end_logits,) + outputs[2:]
|
1605 |
+
if start_positions is not None and end_positions is not None:
|
1606 |
+
# If we are on multi-GPU, split add a dimension
|
1607 |
+
if len(start_positions.size()) > 1:
|
1608 |
+
start_positions = start_positions.squeeze(-1)
|
1609 |
+
if len(end_positions.size()) > 1:
|
1610 |
+
end_positions = end_positions.squeeze(-1)
|
1611 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1612 |
+
ignored_index = start_logits.size(1)
|
1613 |
+
start_positions.clamp_(0, ignored_index)
|
1614 |
+
end_positions.clamp_(0, ignored_index)
|
1615 |
+
|
1616 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1617 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1618 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1619 |
+
total_loss = (start_loss + end_loss) / 2
|
1620 |
+
outputs = (total_loss,) + outputs
|
1621 |
+
|
1622 |
+
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
src/reference_code/evaluate_embeddings.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import DataLoader
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
import numpy
|
5 |
+
|
6 |
+
import pickle
|
7 |
+
import tqdm
|
8 |
+
|
9 |
+
from ..bert import BERT
|
10 |
+
from ..vocab import Vocab
|
11 |
+
from ..dataset import TokenizerDataset
|
12 |
+
import argparse
|
13 |
+
from itertools import combinations
|
14 |
+
|
15 |
+
def generate_subset(s):
|
16 |
+
subsets = []
|
17 |
+
for r in range(len(s) + 1):
|
18 |
+
combinations_result = combinations(s, r)
|
19 |
+
if r==1:
|
20 |
+
subsets.extend(([item] for sublist in combinations_result for item in sublist))
|
21 |
+
else:
|
22 |
+
subsets.extend((list(sublist) for sublist in combinations_result))
|
23 |
+
subsets_dict = {i:s for i, s in enumerate(subsets)}
|
24 |
+
return subsets_dict
|
25 |
+
|
26 |
+
if __name__ == "__main__":
|
27 |
+
parser = argparse.ArgumentParser()
|
28 |
+
|
29 |
+
parser.add_argument('-workspace_name', type=str, default=None)
|
30 |
+
parser.add_argument("-seq_len", type=int, default=100, help="maximum sequence length")
|
31 |
+
parser.add_argument('-pretrain', type=bool, default=False)
|
32 |
+
parser.add_argument('-masked_pred', type=bool, default=False)
|
33 |
+
parser.add_argument('-epoch', type=str, default=None)
|
34 |
+
# parser.add_argument('-set_label', type=bool, default=False)
|
35 |
+
# parser.add_argument('--label_standard', nargs='+', type=str, help='List of optional tasks')
|
36 |
+
|
37 |
+
options = parser.parse_args()
|
38 |
+
|
39 |
+
folder_path = options.workspace_name+"/" if options.workspace_name else ""
|
40 |
+
|
41 |
+
# if options.set_label:
|
42 |
+
# label_standard = generate_subset({'optional-tasks-1', 'optional-tasks-2'})
|
43 |
+
# pickle.dump(label_standard, open(f"{folder_path}pretraining/pretrain_opt_label.pkl", "wb"))
|
44 |
+
# else:
|
45 |
+
# label_standard = pickle.load(open(f"{folder_path}pretraining/pretrain_opt_label.pkl", "rb"))
|
46 |
+
# print(f"options.label _standard: {options.label_standard}")
|
47 |
+
vocab_path = f"{folder_path}check/pretraining/vocab.txt"
|
48 |
+
# vocab_path = f"{folder_path}pretraining/vocab.txt"
|
49 |
+
|
50 |
+
|
51 |
+
print("Loading Vocab", vocab_path)
|
52 |
+
vocab_obj = Vocab(vocab_path)
|
53 |
+
vocab_obj.load_vocab()
|
54 |
+
print("Vocab Size: ", len(vocab_obj.vocab))
|
55 |
+
|
56 |
+
# label_standard = list(pickle.load(open(f"dataset/CL4999_1920/{options.workspace_name}/unique_problems_list.pkl", "rb")))
|
57 |
+
# label_standard = generate_subset({'optional-tasks-1', 'optional-tasks-2', 'OptionalTask_1', 'OptionalTask_2'})
|
58 |
+
# pickle.dump(label_standard, open(f"{folder_path}pretraining/pretrain_opt_label.pkl", "wb"))
|
59 |
+
|
60 |
+
if options.masked_pred:
|
61 |
+
str_code = "masked_prediction"
|
62 |
+
output_name = f"{folder_path}output/bert_trained.seq_model.ep{options.epoch}"
|
63 |
+
else:
|
64 |
+
str_code = "masked"
|
65 |
+
output_name = f"{folder_path}output/bert_trained.seq_encoder.model.ep{options.epoch}"
|
66 |
+
|
67 |
+
folder_path = folder_path+"check/"
|
68 |
+
# folder_path = folder_path
|
69 |
+
if options.pretrain:
|
70 |
+
pretrain_file = f"{folder_path}pretraining/pretrain.txt"
|
71 |
+
pretrain_label = f"{folder_path}pretraining/pretrain_opt.pkl"
|
72 |
+
|
73 |
+
# pretrain_file = f"{folder_path}finetuning/train.txt"
|
74 |
+
# pretrain_label = f"{folder_path}finetuning/train_label.txt"
|
75 |
+
|
76 |
+
embedding_file_path = f"{folder_path}embeddings/pretrain_embeddings_{str_code}_{options.epoch}.pkl"
|
77 |
+
print("Loading Pretrain Dataset ", pretrain_file)
|
78 |
+
pretrain_dataset = TokenizerDataset(pretrain_file, pretrain_label, vocab_obj, seq_len=options.seq_len)
|
79 |
+
|
80 |
+
print("Creating Dataloader")
|
81 |
+
pretrain_data_loader = DataLoader(pretrain_dataset, batch_size=32, num_workers=4)
|
82 |
+
else:
|
83 |
+
val_file = f"{folder_path}pretraining/test.txt"
|
84 |
+
val_label = f"{folder_path}pretraining/test_opt.txt"
|
85 |
+
|
86 |
+
# val_file = f"{folder_path}finetuning/test.txt"
|
87 |
+
# val_label = f"{folder_path}finetuning/test_label.txt"
|
88 |
+
embedding_file_path = f"{folder_path}embeddings/test_embeddings_{str_code}_{options.epoch}.pkl"
|
89 |
+
|
90 |
+
print("Loading Validation Dataset ", val_file)
|
91 |
+
val_dataset = TokenizerDataset(val_file, val_label, vocab_obj, seq_len=options.seq_len)
|
92 |
+
|
93 |
+
print("Creating Dataloader")
|
94 |
+
val_data_loader = DataLoader(val_dataset, batch_size=32, num_workers=4)
|
95 |
+
|
96 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
97 |
+
print(device)
|
98 |
+
print("Load Pre-trained BERT model...")
|
99 |
+
print(output_name)
|
100 |
+
bert = torch.load(output_name, map_location=device)
|
101 |
+
# learned_parameters = model_ep0.state_dict()
|
102 |
+
for param in bert.parameters():
|
103 |
+
param.requires_grad = False
|
104 |
+
|
105 |
+
if options.pretrain:
|
106 |
+
print("Pretrain-embeddings....")
|
107 |
+
data_iter = tqdm.tqdm(enumerate(pretrain_data_loader),
|
108 |
+
desc="pre-train",
|
109 |
+
total=len(pretrain_data_loader),
|
110 |
+
bar_format="{l_bar}{r_bar}")
|
111 |
+
pretrain_embeddings = []
|
112 |
+
for i, data in data_iter:
|
113 |
+
data = {key: value.to(device) for key, value in data.items()}
|
114 |
+
hrep = bert(data["bert_input"], data["segment_label"])
|
115 |
+
# print(hrep[:,0].cpu().detach().numpy())
|
116 |
+
embeddings = [h for h in hrep[:,0].cpu().detach().numpy()]
|
117 |
+
pretrain_embeddings.extend(embeddings)
|
118 |
+
pickle.dump(pretrain_embeddings, open(embedding_file_path,"wb"))
|
119 |
+
# pickle.dump(pretrain_embeddings, open("embeddings/finetune_cfa_train_embeddings.pkl","wb"))
|
120 |
+
|
121 |
+
else:
|
122 |
+
print("Validation-embeddings....")
|
123 |
+
data_iter = tqdm.tqdm(enumerate(val_data_loader),
|
124 |
+
desc="validation",
|
125 |
+
total=len(val_data_loader),
|
126 |
+
bar_format="{l_bar}{r_bar}")
|
127 |
+
val_embeddings = []
|
128 |
+
for i, data in data_iter:
|
129 |
+
data = {key: value.to(device) for key, value in data.items()}
|
130 |
+
hrep = bert(data["bert_input"], data["segment_label"])
|
131 |
+
# print(,hrep[:,0].shape)
|
132 |
+
embeddings = [h for h in hrep[:,0].cpu().detach().numpy()]
|
133 |
+
val_embeddings.extend(embeddings)
|
134 |
+
pickle.dump(val_embeddings, open(embedding_file_path,"wb"))
|
135 |
+
# pickle.dump(val_embeddings, open("embeddings/finetune_cfa_test_embeddings.pkl","wb"))
|
136 |
+
|
src/reference_code/metrics.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.special import softmax
|
3 |
+
|
4 |
+
|
5 |
+
class CELoss(object):
|
6 |
+
|
7 |
+
def compute_bin_boundaries(self, probabilities = np.array([])):
|
8 |
+
|
9 |
+
#uniform bin spacing
|
10 |
+
if probabilities.size == 0:
|
11 |
+
bin_boundaries = np.linspace(0, 1, self.n_bins + 1)
|
12 |
+
self.bin_lowers = bin_boundaries[:-1]
|
13 |
+
self.bin_uppers = bin_boundaries[1:]
|
14 |
+
else:
|
15 |
+
#size of bins
|
16 |
+
bin_n = int(self.n_data/self.n_bins)
|
17 |
+
|
18 |
+
bin_boundaries = np.array([])
|
19 |
+
|
20 |
+
probabilities_sort = np.sort(probabilities)
|
21 |
+
|
22 |
+
for i in range(0,self.n_bins):
|
23 |
+
bin_boundaries = np.append(bin_boundaries,probabilities_sort[i*bin_n])
|
24 |
+
bin_boundaries = np.append(bin_boundaries,1.0)
|
25 |
+
|
26 |
+
self.bin_lowers = bin_boundaries[:-1]
|
27 |
+
self.bin_uppers = bin_boundaries[1:]
|
28 |
+
|
29 |
+
|
30 |
+
def get_probabilities(self, output, labels, logits):
|
31 |
+
#If not probabilities apply softmax!
|
32 |
+
if logits:
|
33 |
+
self.probabilities = softmax(output, axis=1)
|
34 |
+
else:
|
35 |
+
self.probabilities = output
|
36 |
+
|
37 |
+
self.labels = np.argmax(labels, axis=1)
|
38 |
+
self.confidences = np.max(self.probabilities, axis=1)
|
39 |
+
self.predictions = np.argmax(self.probabilities, axis=1)
|
40 |
+
self.accuracies = np.equal(self.predictions, self.labels)
|
41 |
+
|
42 |
+
def binary_matrices(self):
|
43 |
+
idx = np.arange(self.n_data)
|
44 |
+
#make matrices of zeros
|
45 |
+
pred_matrix = np.zeros([self.n_data,self.n_class])
|
46 |
+
label_matrix = np.zeros([self.n_data,self.n_class])
|
47 |
+
#self.acc_matrix = np.zeros([self.n_data,self.n_class])
|
48 |
+
pred_matrix[idx,self.predictions] = 1
|
49 |
+
label_matrix[idx,self.labels] = 1
|
50 |
+
|
51 |
+
self.acc_matrix = np.equal(pred_matrix, label_matrix)
|
52 |
+
|
53 |
+
|
54 |
+
def compute_bins(self, index = None):
|
55 |
+
self.bin_prop = np.zeros(self.n_bins)
|
56 |
+
self.bin_acc = np.zeros(self.n_bins)
|
57 |
+
self.bin_conf = np.zeros(self.n_bins)
|
58 |
+
self.bin_score = np.zeros(self.n_bins)
|
59 |
+
|
60 |
+
if index == None:
|
61 |
+
confidences = self.confidences
|
62 |
+
accuracies = self.accuracies
|
63 |
+
else:
|
64 |
+
confidences = self.probabilities[:,index]
|
65 |
+
accuracies = self.acc_matrix[:,index]
|
66 |
+
|
67 |
+
|
68 |
+
for i, (bin_lower, bin_upper) in enumerate(zip(self.bin_lowers, self.bin_uppers)):
|
69 |
+
# Calculated |confidence - accuracy| in each bin
|
70 |
+
in_bin = np.greater(confidences,bin_lower.item()) * np.less_equal(confidences,bin_upper.item())
|
71 |
+
self.bin_prop[i] = np.mean(in_bin)
|
72 |
+
|
73 |
+
if self.bin_prop[i].item() > 0:
|
74 |
+
self.bin_acc[i] = np.mean(accuracies[in_bin])
|
75 |
+
self.bin_conf[i] = np.mean(confidences[in_bin])
|
76 |
+
self.bin_score[i] = np.abs(self.bin_conf[i] - self.bin_acc[i])
|
77 |
+
|
78 |
+
class MaxProbCELoss(CELoss):
|
79 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
80 |
+
self.n_bins = n_bins
|
81 |
+
super().compute_bin_boundaries()
|
82 |
+
super().get_probabilities(output, labels, logits)
|
83 |
+
super().compute_bins()
|
84 |
+
|
85 |
+
#http://people.cs.pitt.edu/~milos/research/AAAI_Calibration.pdf
|
86 |
+
class ECELoss(MaxProbCELoss):
|
87 |
+
|
88 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
89 |
+
super().loss(output, labels, n_bins, logits)
|
90 |
+
return np.dot(self.bin_prop,self.bin_score)
|
91 |
+
|
92 |
+
class MCELoss(MaxProbCELoss):
|
93 |
+
|
94 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
95 |
+
super().loss(output, labels, n_bins, logits)
|
96 |
+
return np.max(self.bin_score)
|
97 |
+
|
98 |
+
#https://arxiv.org/abs/1905.11001
|
99 |
+
#Overconfidence Loss (Good in high risk applications where confident but wrong predictions can be especially harmful)
|
100 |
+
class OELoss(MaxProbCELoss):
|
101 |
+
|
102 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
103 |
+
super().loss(output, labels, n_bins, logits)
|
104 |
+
return np.dot(self.bin_prop,self.bin_conf * np.maximum(self.bin_conf-self.bin_acc,np.zeros(self.n_bins)))
|
105 |
+
|
106 |
+
|
107 |
+
#https://arxiv.org/abs/1904.01685
|
108 |
+
class SCELoss(CELoss):
|
109 |
+
|
110 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
111 |
+
sce = 0.0
|
112 |
+
self.n_bins = n_bins
|
113 |
+
self.n_data = len(output)
|
114 |
+
self.n_class = len(output[0])
|
115 |
+
|
116 |
+
super().compute_bin_boundaries()
|
117 |
+
super().get_probabilities(output, labels, logits)
|
118 |
+
super().binary_matrices()
|
119 |
+
|
120 |
+
for i in range(self.n_class):
|
121 |
+
super().compute_bins(i)
|
122 |
+
sce += np.dot(self.bin_prop,self.bin_score)
|
123 |
+
|
124 |
+
return sce/self.n_class
|
125 |
+
|
126 |
+
class TACELoss(CELoss):
|
127 |
+
|
128 |
+
def loss(self, output, labels, threshold = 0.01, n_bins = 15, logits = True):
|
129 |
+
tace = 0.0
|
130 |
+
self.n_bins = n_bins
|
131 |
+
self.n_data = len(output)
|
132 |
+
self.n_class = len(output[0])
|
133 |
+
|
134 |
+
super().get_probabilities(output, labels, logits)
|
135 |
+
self.probabilities[self.probabilities < threshold] = 0
|
136 |
+
super().binary_matrices()
|
137 |
+
|
138 |
+
for i in range(self.n_class):
|
139 |
+
super().compute_bin_boundaries(self.probabilities[:,i])
|
140 |
+
super().compute_bins(i)
|
141 |
+
tace += np.dot(self.bin_prop,self.bin_score)
|
142 |
+
|
143 |
+
return tace/self.n_class
|
144 |
+
|
145 |
+
#create TACELoss with threshold fixed at 0
|
146 |
+
class ACELoss(TACELoss):
|
147 |
+
|
148 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
149 |
+
return super().loss(output, labels, 0.0 , n_bins, logits)
|
src/reference_code/pretrainer-old.py
ADDED
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from torch.optim import Adam, SGD
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
import pickle
|
7 |
+
|
8 |
+
from ..bert import BERT
|
9 |
+
from ..seq_model import BERTSM
|
10 |
+
from ..classifier_model import BERTForClassification
|
11 |
+
from ..optim_schedule import ScheduledOptim
|
12 |
+
|
13 |
+
import tqdm
|
14 |
+
import sys
|
15 |
+
import time
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
# import visualization
|
19 |
+
|
20 |
+
from sklearn.metrics import precision_score, recall_score, f1_score
|
21 |
+
|
22 |
+
import matplotlib.pyplot as plt
|
23 |
+
import seaborn as sns
|
24 |
+
import pandas as pd
|
25 |
+
from collections import defaultdict
|
26 |
+
import os
|
27 |
+
|
28 |
+
class ECE(nn.Module):
|
29 |
+
|
30 |
+
def __init__(self, n_bins=15):
|
31 |
+
"""
|
32 |
+
n_bins (int): number of confidence interval bins
|
33 |
+
"""
|
34 |
+
super(ECE, self).__init__()
|
35 |
+
bin_boundaries = torch.linspace(0, 1, n_bins + 1)
|
36 |
+
self.bin_lowers = bin_boundaries[:-1]
|
37 |
+
self.bin_uppers = bin_boundaries[1:]
|
38 |
+
|
39 |
+
def forward(self, logits, labels):
|
40 |
+
softmaxes = F.softmax(logits, dim=1)
|
41 |
+
confidences, predictions = torch.max(softmaxes, 1)
|
42 |
+
labels = torch.argmax(labels,1)
|
43 |
+
accuracies = predictions.eq(labels)
|
44 |
+
|
45 |
+
ece = torch.zeros(1, device=logits.device)
|
46 |
+
for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
|
47 |
+
# Calculated |confidence - accuracy| in each bin
|
48 |
+
in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
|
49 |
+
prop_in_bin = in_bin.float().mean()
|
50 |
+
if prop_in_bin.item() > 0:
|
51 |
+
accuracy_in_bin = accuracies[in_bin].float().mean()
|
52 |
+
avg_confidence_in_bin = confidences[in_bin].mean()
|
53 |
+
ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
|
54 |
+
|
55 |
+
return ece
|
56 |
+
|
57 |
+
def accurate_nb(preds, labels):
|
58 |
+
pred_flat = np.argmax(preds, axis=1).flatten()
|
59 |
+
labels_flat = np.argmax(labels, axis=1).flatten()
|
60 |
+
labels_flat = labels.flatten()
|
61 |
+
return np.sum(pred_flat == labels_flat)
|
62 |
+
|
63 |
+
class BERTTrainer:
|
64 |
+
"""
|
65 |
+
BERTTrainer pretrains BERT model on input sequence of strategies.
|
66 |
+
BERTTrainer make the pretrained BERT model with one training method objective.
|
67 |
+
1. Masked Strategy Modelling : 3.3.1 Task #1: Masked SM
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, bert: BERT, vocab_size: int,
|
71 |
+
train_dataloader: DataLoader, val_dataloader: DataLoader = None, test_dataloader: DataLoader = None,
|
72 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=5000,
|
73 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, same_student_prediction = False,
|
74 |
+
workspace_name=None, code=None):
|
75 |
+
"""
|
76 |
+
:param bert: BERT model which you want to train
|
77 |
+
:param vocab_size: total word vocab size
|
78 |
+
:param train_dataloader: train dataset data loader
|
79 |
+
:param test_dataloader: test dataset data loader [can be None]
|
80 |
+
:param lr: learning rate of optimizer
|
81 |
+
:param betas: Adam optimizer betas
|
82 |
+
:param weight_decay: Adam optimizer weight decay param
|
83 |
+
:param with_cuda: traning with cuda
|
84 |
+
:param log_freq: logging frequency of the batch iteration
|
85 |
+
"""
|
86 |
+
|
87 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
88 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
89 |
+
print(cuda_condition, " Device used = ", self.device)
|
90 |
+
|
91 |
+
available_gpus = list(range(torch.cuda.device_count()))
|
92 |
+
|
93 |
+
# This BERT model will be saved every epoch
|
94 |
+
self.bert = bert.to(self.device)
|
95 |
+
# Initialize the BERT Language Model, with BERT model
|
96 |
+
self.model = BERTSM(bert, vocab_size).to(self.device)
|
97 |
+
|
98 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
99 |
+
if with_cuda and torch.cuda.device_count() > 1:
|
100 |
+
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
101 |
+
self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
102 |
+
|
103 |
+
# Setting the train and test data loader
|
104 |
+
self.train_data = train_dataloader
|
105 |
+
self.val_data = val_dataloader
|
106 |
+
self.test_data = test_dataloader
|
107 |
+
|
108 |
+
# Setting the Adam optimizer with hyper-param
|
109 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
110 |
+
self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
|
111 |
+
|
112 |
+
# Using Negative Log Likelihood Loss function for predicting the masked_token
|
113 |
+
self.criterion = nn.NLLLoss(ignore_index=0)
|
114 |
+
|
115 |
+
self.log_freq = log_freq
|
116 |
+
self.same_student_prediction = same_student_prediction
|
117 |
+
self.workspace_name = workspace_name
|
118 |
+
self.save_model = False
|
119 |
+
self.code = code
|
120 |
+
self.avg_loss = 10000
|
121 |
+
self.start_time = time.time()
|
122 |
+
|
123 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
124 |
+
|
125 |
+
def train(self, epoch):
|
126 |
+
self.iteration(epoch, self.train_data)
|
127 |
+
|
128 |
+
def val(self, epoch):
|
129 |
+
self.iteration(epoch, self.val_data, phase="val")
|
130 |
+
|
131 |
+
def test(self, epoch):
|
132 |
+
self.iteration(epoch, self.test_data, phase="test")
|
133 |
+
|
134 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
135 |
+
"""
|
136 |
+
loop over the data_loader for training or testing
|
137 |
+
if on train status, backward operation is activated
|
138 |
+
and also auto save the model every peoch
|
139 |
+
|
140 |
+
:param epoch: current epoch index
|
141 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
142 |
+
:param train: boolean value of is train or test
|
143 |
+
:return: None
|
144 |
+
"""
|
145 |
+
# str_code = "train" if train else "test"
|
146 |
+
# code = "masked_prediction" if self.same_student_prediction else "masked"
|
147 |
+
|
148 |
+
self.log_file = f"{self.workspace_name}/logs/{self.code}/log_{phase}_pretrained.txt"
|
149 |
+
# bert_hidden_representations = []
|
150 |
+
if epoch == 0:
|
151 |
+
f = open(self.log_file, 'w')
|
152 |
+
f.close()
|
153 |
+
if phase == "val":
|
154 |
+
self.avg_loss = 10000
|
155 |
+
# Setting the tqdm progress bar
|
156 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
157 |
+
desc="EP_%s:%d" % (phase, epoch),
|
158 |
+
total=len(data_loader),
|
159 |
+
bar_format="{l_bar}{r_bar}")
|
160 |
+
|
161 |
+
avg_loss_mask = 0.0
|
162 |
+
total_correct_mask = 0
|
163 |
+
total_element_mask = 0
|
164 |
+
|
165 |
+
avg_loss_pred = 0.0
|
166 |
+
total_correct_pred = 0
|
167 |
+
total_element_pred = 0
|
168 |
+
|
169 |
+
avg_loss = 0.0
|
170 |
+
|
171 |
+
if phase == "train":
|
172 |
+
self.model.train()
|
173 |
+
else:
|
174 |
+
self.model.eval()
|
175 |
+
with open(self.log_file, 'a') as f:
|
176 |
+
sys.stdout = f
|
177 |
+
for i, data in data_iter:
|
178 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
179 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
180 |
+
# if i == 0:
|
181 |
+
# print(f"data : {data[0]}")
|
182 |
+
# 1. forward the next_sentence_prediction and masked_lm model
|
183 |
+
# next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])
|
184 |
+
if self.same_student_prediction:
|
185 |
+
bert_hidden_rep, mask_lm_output, same_student_output = self.model.forward(data["bert_input"], data["segment_label"], self.same_student_prediction)
|
186 |
+
else:
|
187 |
+
bert_hidden_rep, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"], self.same_student_prediction)
|
188 |
+
|
189 |
+
# embeddings = [h for h in bert_hidden_rep.cpu().detach().numpy()]
|
190 |
+
# bert_hidden_representations.extend(embeddings)
|
191 |
+
|
192 |
+
|
193 |
+
# 2-2. NLLLoss of predicting masked token word
|
194 |
+
mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
|
195 |
+
|
196 |
+
# 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
|
197 |
+
if self.same_student_prediction:
|
198 |
+
# 2-1. NLL(negative log likelihood) loss of is_next classification result
|
199 |
+
same_student_loss = self.criterion(same_student_output, data["is_same_student"])
|
200 |
+
loss = same_student_loss + mask_loss
|
201 |
+
else:
|
202 |
+
loss = mask_loss
|
203 |
+
|
204 |
+
# 3. backward and optimization only in train
|
205 |
+
if phase == "train":
|
206 |
+
self.optim_schedule.zero_grad()
|
207 |
+
loss.backward()
|
208 |
+
self.optim_schedule.step_and_update_lr()
|
209 |
+
|
210 |
+
|
211 |
+
# print(f"mask_lm_output : {mask_lm_output}")
|
212 |
+
# non_zero_mask = (data["bert_label"] != 0).float()
|
213 |
+
# print(f"bert_label : {data['bert_label']}")
|
214 |
+
non_zero_mask = (data["bert_label"] != 0).float()
|
215 |
+
predictions = torch.argmax(mask_lm_output, dim=-1)
|
216 |
+
# print(f"predictions : {predictions}")
|
217 |
+
predicted_masked = predictions*non_zero_mask
|
218 |
+
# print(f"predicted_masked : {predicted_masked}")
|
219 |
+
mask_correct = ((data["bert_label"] == predicted_masked)*non_zero_mask).sum().item()
|
220 |
+
# print(f"mask_correct : {mask_correct}")
|
221 |
+
# print(f"non_zero_mask.sum().item() : {non_zero_mask.sum().item()}")
|
222 |
+
|
223 |
+
avg_loss_mask += loss.item()
|
224 |
+
total_correct_mask += mask_correct
|
225 |
+
total_element_mask += non_zero_mask.sum().item()
|
226 |
+
# total_element_mask += data["bert_label"].sum().item()
|
227 |
+
|
228 |
+
torch.cuda.empty_cache()
|
229 |
+
post_fix = {
|
230 |
+
"epoch": epoch,
|
231 |
+
"iter": i,
|
232 |
+
"avg_loss": avg_loss_mask / (i + 1),
|
233 |
+
"avg_acc_mask": (total_correct_mask / total_element_mask * 100) if total_element_mask != 0 else 0,
|
234 |
+
"loss": loss.item()
|
235 |
+
}
|
236 |
+
|
237 |
+
# next sentence prediction accuracy
|
238 |
+
if self.same_student_prediction:
|
239 |
+
correct = same_student_output.argmax(dim=-1).eq(data["is_same_student"]).sum().item()
|
240 |
+
avg_loss_pred += loss.item()
|
241 |
+
total_correct_pred += correct
|
242 |
+
total_element_pred += data["is_same_student"].nelement()
|
243 |
+
# correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
|
244 |
+
post_fix["avg_loss"] = avg_loss_pred / (i + 1)
|
245 |
+
post_fix["avg_acc_pred"] = total_correct_pred / total_element_pred * 100
|
246 |
+
post_fix["loss"] = loss.item()
|
247 |
+
|
248 |
+
avg_loss +=loss.item()
|
249 |
+
|
250 |
+
if i % self.log_freq == 0:
|
251 |
+
data_iter.write(str(post_fix))
|
252 |
+
# if not train and epoch > 20 :
|
253 |
+
# pickle.dump(mask_lm_output.cpu().detach().numpy(), open(f"logs/mask/mask_out_e{epoch}_{i}.pkl","wb"))
|
254 |
+
# pickle.dump(data["bert_label"].cpu().detach().numpy(), open(f"logs/mask/label_e{epoch}_{i}.pkl","wb"))
|
255 |
+
end_time = time.time()
|
256 |
+
final_msg = {
|
257 |
+
"epoch": f"EP{epoch}_{phase}",
|
258 |
+
"avg_loss": avg_loss / len(data_iter),
|
259 |
+
"total_masked_acc": total_correct_mask * 100.0 / total_element_mask if total_element_mask != 0 else 0,
|
260 |
+
"time_taken_from_start": end_time - self.start_time
|
261 |
+
}
|
262 |
+
|
263 |
+
if self.same_student_prediction:
|
264 |
+
final_msg["total_prediction_acc"] = total_correct_pred * 100.0 / total_element_pred
|
265 |
+
|
266 |
+
print(final_msg)
|
267 |
+
|
268 |
+
f.close()
|
269 |
+
sys.stdout = sys.__stdout__
|
270 |
+
|
271 |
+
if phase == "val":
|
272 |
+
self.save_model = False
|
273 |
+
if self.avg_loss > (avg_loss / len(data_iter)):
|
274 |
+
self.save_model = True
|
275 |
+
self.avg_loss = (avg_loss / len(data_iter))
|
276 |
+
|
277 |
+
# pickle.dump(bert_hidden_representations, open(f"embeddings/{code}/{str_code}_embeddings_{epoch}.pkl","wb"))
|
278 |
+
|
279 |
+
|
280 |
+
|
281 |
+
def save(self, epoch, file_path="output/bert_trained.model"):
|
282 |
+
"""
|
283 |
+
Saving the current BERT model on file_path
|
284 |
+
|
285 |
+
:param epoch: current epoch number
|
286 |
+
:param file_path: model output path which gonna be file_path+"ep%d" % epoch
|
287 |
+
:return: final_output_path
|
288 |
+
"""
|
289 |
+
# if self.code:
|
290 |
+
# fpath = file_path.split("/")
|
291 |
+
# # output_path = fpath[0]+ "/"+ fpath[1]+f"/{self.code}/" + fpath[2] + ".ep%d" % epoch
|
292 |
+
# output_path = "/",join(fpath[0]+ "/"+ fpath[1]+f"/{self.code}/" + fpath[-1] + ".ep%d" % epoch
|
293 |
+
|
294 |
+
# else:
|
295 |
+
output_path = file_path + ".ep%d" % epoch
|
296 |
+
|
297 |
+
torch.save(self.bert.cpu(), output_path)
|
298 |
+
self.bert.to(self.device)
|
299 |
+
print("EP:%d Model Saved on:" % epoch, output_path)
|
300 |
+
return output_path
|
301 |
+
|
302 |
+
|
303 |
+
class BERTFineTuneTrainer:
|
304 |
+
|
305 |
+
def __init__(self, bert: BERT, vocab_size: int,
|
306 |
+
train_dataloader: DataLoader, test_dataloader: DataLoader = None,
|
307 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
308 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
309 |
+
num_labels=2, finetune_task=""):
|
310 |
+
"""
|
311 |
+
:param bert: BERT model which you want to train
|
312 |
+
:param vocab_size: total word vocab size
|
313 |
+
:param train_dataloader: train dataset data loader
|
314 |
+
:param test_dataloader: test dataset data loader [can be None]
|
315 |
+
:param lr: learning rate of optimizer
|
316 |
+
:param betas: Adam optimizer betas
|
317 |
+
:param weight_decay: Adam optimizer weight decay param
|
318 |
+
:param with_cuda: traning with cuda
|
319 |
+
:param log_freq: logging frequency of the batch iteration
|
320 |
+
"""
|
321 |
+
|
322 |
+
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
323 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
324 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
325 |
+
print(with_cuda, cuda_condition, " Device used = ", self.device)
|
326 |
+
|
327 |
+
# This BERT model will be saved every epoch
|
328 |
+
self.bert = bert
|
329 |
+
for param in self.bert.parameters():
|
330 |
+
param.requires_grad = False
|
331 |
+
# Initialize the BERT Language Model, with BERT model
|
332 |
+
self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
333 |
+
|
334 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
335 |
+
if with_cuda and torch.cuda.device_count() > 1:
|
336 |
+
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
337 |
+
self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
|
338 |
+
|
339 |
+
# Setting the train and test data loader
|
340 |
+
self.train_data = train_dataloader
|
341 |
+
self.test_data = test_dataloader
|
342 |
+
|
343 |
+
self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
344 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
345 |
+
|
346 |
+
if num_labels == 1:
|
347 |
+
self.criterion = nn.MSELoss()
|
348 |
+
elif num_labels == 2:
|
349 |
+
self.criterion = nn.BCEWithLogitsLoss()
|
350 |
+
# self.criterion = nn.CrossEntropyLoss()
|
351 |
+
elif num_labels > 2:
|
352 |
+
self.criterion = nn.CrossEntropyLoss()
|
353 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
354 |
+
|
355 |
+
# self.ece_criterion = ECE().to(self.device)
|
356 |
+
|
357 |
+
self.log_freq = log_freq
|
358 |
+
self.workspace_name = workspace_name
|
359 |
+
self.finetune_task = finetune_task
|
360 |
+
self.save_model = False
|
361 |
+
self.avg_loss = 10000
|
362 |
+
self.start_time = time.time()
|
363 |
+
self.probability_list = []
|
364 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
365 |
+
|
366 |
+
def train(self, epoch):
|
367 |
+
self.iteration(epoch, self.train_data)
|
368 |
+
|
369 |
+
def test(self, epoch):
|
370 |
+
self.iteration(epoch, self.test_data, train=False)
|
371 |
+
|
372 |
+
def iteration(self, epoch, data_loader, train=True):
|
373 |
+
"""
|
374 |
+
loop over the data_loader for training or testing
|
375 |
+
if on train status, backward operation is activated
|
376 |
+
and also auto save the model every peoch
|
377 |
+
|
378 |
+
:param epoch: current epoch index
|
379 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
380 |
+
:param train: boolean value of is train or test
|
381 |
+
:return: None
|
382 |
+
"""
|
383 |
+
str_code = "train" if train else "test"
|
384 |
+
|
385 |
+
self.log_file = f"{self.workspace_name}/logs/{self.finetune_task}/log_{str_code}_finetuned.txt"
|
386 |
+
|
387 |
+
if epoch == 0:
|
388 |
+
f = open(self.log_file, 'w')
|
389 |
+
f.close()
|
390 |
+
if not train:
|
391 |
+
self.avg_loss = 10000
|
392 |
+
|
393 |
+
# Setting the tqdm progress bar
|
394 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
395 |
+
desc="EP_%s:%d" % (str_code, epoch),
|
396 |
+
total=len(data_loader),
|
397 |
+
bar_format="{l_bar}{r_bar}")
|
398 |
+
|
399 |
+
avg_loss = 0.0
|
400 |
+
total_correct = 0
|
401 |
+
total_element = 0
|
402 |
+
plabels = []
|
403 |
+
tlabels = []
|
404 |
+
|
405 |
+
eval_accurate_nb = 0
|
406 |
+
nb_eval_examples = 0
|
407 |
+
logits_list = []
|
408 |
+
labels_list = []
|
409 |
+
|
410 |
+
if train:
|
411 |
+
self.model.train()
|
412 |
+
else:
|
413 |
+
self.model.eval()
|
414 |
+
self.probability_list = []
|
415 |
+
with open(self.log_file, 'a') as f:
|
416 |
+
sys.stdout = f
|
417 |
+
|
418 |
+
for i, data in data_iter:
|
419 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
420 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
421 |
+
if train:
|
422 |
+
h_rep, logits = self.model.forward(data["bert_input"], data["segment_label"])
|
423 |
+
else:
|
424 |
+
with torch.no_grad():
|
425 |
+
h_rep, logits = self.model.forward(data["bert_input"], data["segment_label"])
|
426 |
+
# print(logits, logits.shape)
|
427 |
+
logits_list.append(logits.cpu())
|
428 |
+
labels_list.append(data["progress_status"].cpu())
|
429 |
+
# print(">>>>>>>>>>>>", progress_output)
|
430 |
+
# print(f"{epoch}---nelement--- {data['progress_status'].nelement()}")
|
431 |
+
# print(data["progress_status"].shape, logits.shape)
|
432 |
+
progress_loss = self.criterion(logits, data["progress_status"])
|
433 |
+
loss = progress_loss
|
434 |
+
|
435 |
+
if torch.cuda.device_count() > 1:
|
436 |
+
loss = loss.mean()
|
437 |
+
|
438 |
+
# 3. backward and optimization only in train
|
439 |
+
if train:
|
440 |
+
self.optim.zero_grad()
|
441 |
+
loss.backward()
|
442 |
+
# torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
443 |
+
self.optim.step()
|
444 |
+
|
445 |
+
# progress prediction accuracy
|
446 |
+
# correct = progress_output.argmax(dim=-1).eq(data["progress_status"]).sum().item()
|
447 |
+
probs = nn.LogSoftmax(dim=-1)(logits)
|
448 |
+
self.probability_list.append(probs)
|
449 |
+
predicted_labels = torch.argmax(probs, dim=-1)
|
450 |
+
true_labels = torch.argmax(data["progress_status"], dim=-1)
|
451 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
452 |
+
tlabels.extend(true_labels.cpu().numpy())
|
453 |
+
|
454 |
+
# Compare predicted labels to true labels and calculate accuracy
|
455 |
+
correct = (predicted_labels == true_labels).sum().item()
|
456 |
+
avg_loss += loss.item()
|
457 |
+
total_correct += correct
|
458 |
+
# total_element += true_labels.nelement()
|
459 |
+
total_element += data["progress_status"].nelement()
|
460 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
461 |
+
|
462 |
+
# if train:
|
463 |
+
post_fix = {
|
464 |
+
"epoch": epoch,
|
465 |
+
"iter": i,
|
466 |
+
"avg_loss": avg_loss / (i + 1),
|
467 |
+
"avg_acc": total_correct / total_element * 100,
|
468 |
+
"loss": loss.item()
|
469 |
+
}
|
470 |
+
# else:
|
471 |
+
# logits = logits.detach().cpu().numpy()
|
472 |
+
# label_ids = data["progress_status"].to('cpu').numpy()
|
473 |
+
# tmp_eval_nb = accurate_nb(logits, label_ids)
|
474 |
+
|
475 |
+
# eval_accurate_nb += tmp_eval_nb
|
476 |
+
# nb_eval_examples += label_ids.shape[0]
|
477 |
+
|
478 |
+
# # total_element += data["progress_status"].nelement()
|
479 |
+
# # avg_loss += loss.item()
|
480 |
+
|
481 |
+
# post_fix = {
|
482 |
+
# "epoch": epoch,
|
483 |
+
# "iter": i,
|
484 |
+
# "avg_loss": avg_loss / (i + 1),
|
485 |
+
# "avg_acc": tmp_eval_nb / total_element * 100,
|
486 |
+
# "loss": loss.item()
|
487 |
+
# }
|
488 |
+
|
489 |
+
|
490 |
+
if i % self.log_freq == 0:
|
491 |
+
data_iter.write(str(post_fix))
|
492 |
+
|
493 |
+
# precisions = precision_score(plabels, tlabels, average="weighted")
|
494 |
+
# recalls = recall_score(plabels, tlabels, average="weighted")
|
495 |
+
f1_scores = f1_score(plabels, tlabels, average="weighted")
|
496 |
+
# if train:
|
497 |
+
end_time = time.time()
|
498 |
+
final_msg = {
|
499 |
+
"epoch": f"EP{epoch}_{str_code}",
|
500 |
+
"avg_loss": avg_loss / len(data_iter),
|
501 |
+
"total_acc": total_correct * 100.0 / total_element,
|
502 |
+
# "precisions": precisions,
|
503 |
+
# "recalls": recalls,
|
504 |
+
"f1_scores": f1_scores,
|
505 |
+
"time_taken_from_start": end_time - self.start_time
|
506 |
+
}
|
507 |
+
# else:
|
508 |
+
# eval_accuracy = eval_accurate_nb/nb_eval_examples
|
509 |
+
|
510 |
+
# logits_ece = torch.cat(logits_list)
|
511 |
+
# labels_ece = torch.cat(labels_list)
|
512 |
+
# ece = self.ece_criterion(logits_ece, labels_ece).item()
|
513 |
+
# end_time = time.time()
|
514 |
+
# final_msg = {
|
515 |
+
# "epoch": f"EP{epoch}_{str_code}",
|
516 |
+
# "eval_accuracy": eval_accuracy,
|
517 |
+
# "ece": ece,
|
518 |
+
# "avg_loss": avg_loss / len(data_iter),
|
519 |
+
# "precisions": precisions,
|
520 |
+
# "recalls": recalls,
|
521 |
+
# "f1_scores": f1_scores,
|
522 |
+
# "time_taken_from_start": end_time - self.start_time
|
523 |
+
# }
|
524 |
+
# if self.save_model:
|
525 |
+
# conf_hist = visualization.ConfidenceHistogram()
|
526 |
+
# plt_test = conf_hist.plot(np.array(logits_ece), np.array(labels_ece), title= f"Confidence Histogram {epoch}")
|
527 |
+
# plt_test.savefig(f"{self.workspace_name}/plots/confidence_histogram/{self.finetune_task}/conf_histogram_test_{epoch}.png",bbox_inches='tight')
|
528 |
+
# plt_test.close()
|
529 |
+
|
530 |
+
# rel_diagram = visualization.ReliabilityDiagram()
|
531 |
+
# plt_test_2 = rel_diagram.plot(np.array(logits_ece), np.array(labels_ece),title=f"Reliability Diagram {epoch}")
|
532 |
+
# plt_test_2.savefig(f"{self.workspace_name}/plots/confidence_histogram/{self.finetune_task}/rel_diagram_test_{epoch}.png",bbox_inches='tight')
|
533 |
+
# plt_test_2.close()
|
534 |
+
print(final_msg)
|
535 |
+
|
536 |
+
# print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=", total_correct * 100.0 / total_element)
|
537 |
+
f.close()
|
538 |
+
sys.stdout = sys.__stdout__
|
539 |
+
self.save_model = False
|
540 |
+
if self.avg_loss > (avg_loss / len(data_iter)):
|
541 |
+
self.save_model = True
|
542 |
+
self.avg_loss = (avg_loss / len(data_iter))
|
543 |
+
|
544 |
+
def iteration_1(self, epoch_idx, data):
|
545 |
+
try:
|
546 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
547 |
+
logits = self.model(data['input_ids'], data['segment_label'])
|
548 |
+
# Ensure logits is a tensor, not a tuple
|
549 |
+
loss_fct = nn.CrossEntropyLoss()
|
550 |
+
loss = loss_fct(logits, data['labels'])
|
551 |
+
|
552 |
+
# Backpropagation and optimization
|
553 |
+
self.optim.zero_grad()
|
554 |
+
loss.backward()
|
555 |
+
self.optim.step()
|
556 |
+
|
557 |
+
if self.log_freq > 0 and epoch_idx % self.log_freq == 0:
|
558 |
+
print(f"Epoch {epoch_idx}: Loss = {loss.item()}")
|
559 |
+
|
560 |
+
return loss
|
561 |
+
|
562 |
+
except Exception as e:
|
563 |
+
print(f"Error during iteration: {e}")
|
564 |
+
raise
|
565 |
+
|
566 |
+
|
567 |
+
|
568 |
+
|
569 |
+
|
570 |
+
# plt_test.show()
|
571 |
+
# print("EP%d_%s, " % (epoch, str_code))
|
572 |
+
|
573 |
+
def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
|
574 |
+
"""
|
575 |
+
Saving the current BERT model on file_path
|
576 |
+
|
577 |
+
:param epoch: current epoch number
|
578 |
+
:param file_path: model output path which gonna be file_path+"ep%d" % epoch
|
579 |
+
:return: final_output_path
|
580 |
+
"""
|
581 |
+
if self.finetune_task:
|
582 |
+
fpath = file_path.split("/")
|
583 |
+
output_path = fpath[0]+ "/"+ fpath[1]+f"/{self.finetune_task}/" + fpath[2] + ".ep%d" % epoch
|
584 |
+
else:
|
585 |
+
output_path = file_path + ".ep%d" % epoch
|
586 |
+
torch.save(self.model.cpu(), output_path)
|
587 |
+
self.model.to(self.device)
|
588 |
+
print("EP:%d Model Saved on:" % epoch, output_path)
|
589 |
+
return output_path
|
590 |
+
|
591 |
+
|
592 |
+
class BERTAttention:
|
593 |
+
def __init__(self, bert: BERT, vocab_obj, train_dataloader: DataLoader, workspace_name=None, code=None, finetune_task=None, with_cuda=True):
|
594 |
+
|
595 |
+
# available_gpus = list(range(torch.cuda.device_count()))
|
596 |
+
|
597 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
598 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
599 |
+
print(with_cuda, cuda_condition, " Device used = ", self.device)
|
600 |
+
self.bert = bert.to(self.device)
|
601 |
+
|
602 |
+
# if with_cuda and torch.cuda.device_count() > 1:
|
603 |
+
# print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
604 |
+
# self.bert = nn.DataParallel(self.bert, device_ids=available_gpus)
|
605 |
+
|
606 |
+
self.train_dataloader = train_dataloader
|
607 |
+
self.workspace_name = workspace_name
|
608 |
+
self.code = code
|
609 |
+
self.finetune_task = finetune_task
|
610 |
+
self.vocab_obj = vocab_obj
|
611 |
+
|
612 |
+
def getAttention(self):
|
613 |
+
# self.log_file = f"{self.workspace_name}/logs/{self.code}/log_attention.txt"
|
614 |
+
|
615 |
+
|
616 |
+
labels = ['PercentChange', 'NumeratorQuantity2', 'NumeratorQuantity1', 'DenominatorQuantity1',
|
617 |
+
'OptionalTask_1', 'EquationAnswer', 'NumeratorFactor', 'DenominatorFactor',
|
618 |
+
'OptionalTask_2', 'FirstRow1:1', 'FirstRow1:2', 'FirstRow2:1', 'FirstRow2:2', 'SecondRow',
|
619 |
+
'ThirdRow', 'FinalAnswer','FinalAnswerDirection']
|
620 |
+
df_all = pd.DataFrame(0.0, index=labels, columns=labels)
|
621 |
+
# Setting the tqdm progress bar
|
622 |
+
data_iter = tqdm.tqdm(enumerate(self.train_dataloader),
|
623 |
+
desc="attention",
|
624 |
+
total=len(self.train_dataloader),
|
625 |
+
bar_format="{l_bar}{r_bar}")
|
626 |
+
count = 0
|
627 |
+
for i, data in data_iter:
|
628 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
629 |
+
a = self.bert.forward(data["bert_input"], data["segment_label"])
|
630 |
+
non_zero = np.sum(data["segment_label"].cpu().detach().numpy())
|
631 |
+
|
632 |
+
# Last Transformer Layer
|
633 |
+
last_layer = self.bert.attention_values[-1].transpose(1,0,2,3)
|
634 |
+
# print(last_layer.shape)
|
635 |
+
head, d_model, s, s = last_layer.shape
|
636 |
+
|
637 |
+
for d in range(d_model):
|
638 |
+
seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
|
639 |
+
# df_all = pd.DataFrame(0.0, index=seq_labels, columns=seq_labels)
|
640 |
+
indices_to_choose = defaultdict(int)
|
641 |
+
|
642 |
+
for k,s in enumerate(seq_labels):
|
643 |
+
if s in labels:
|
644 |
+
indices_to_choose[s] = k
|
645 |
+
indices_chosen = list(indices_to_choose.values())
|
646 |
+
selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
|
647 |
+
# print(len(seq_labels), len(selected_seq_labels))
|
648 |
+
for h in range(head):
|
649 |
+
# fig, ax = plt.subplots(figsize=(12, 12))
|
650 |
+
# seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])#[1:non_zero-1]
|
651 |
+
# seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
|
652 |
+
# indices_to_choose = defaultdict(int)
|
653 |
+
|
654 |
+
# for k,s in enumerate(seq_labels):
|
655 |
+
# if s in labels:
|
656 |
+
# indices_to_choose[s] = k
|
657 |
+
# indices_chosen = list(indices_to_choose.values())
|
658 |
+
# selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
|
659 |
+
# print(f"Chosen index: {seq_labels, indices_to_choose, indices_chosen, selected_seq_labels}")
|
660 |
+
|
661 |
+
df_cm = pd.DataFrame(last_layer[h][d][indices_chosen,:][:,indices_chosen], index = selected_seq_labels, columns = selected_seq_labels)
|
662 |
+
df_all = df_all.add(df_cm, fill_value=0)
|
663 |
+
count += 1
|
664 |
+
|
665 |
+
# df_cm = pd.DataFrame(last_layer[h][d][1:non_zero-1,:][:,1:non_zero-1], index=seq_labels, columns=seq_labels)
|
666 |
+
# df_all = df_all.add(df_cm, fill_value=0)
|
667 |
+
|
668 |
+
# df_all = df_all.reindex(index=seq_labels, columns=seq_labels)
|
669 |
+
# sns.heatmap(df_all, annot=False)
|
670 |
+
# plt.title("Attentions") #Probabilities
|
671 |
+
# plt.xlabel("Steps")
|
672 |
+
# plt.ylabel("Steps")
|
673 |
+
# plt.grid(True)
|
674 |
+
# plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
|
675 |
+
# plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores_over_[{h}]_head_n_data[{d}].png", bbox_inches='tight')
|
676 |
+
# plt.show()
|
677 |
+
# plt.close()
|
678 |
+
|
679 |
+
|
680 |
+
|
681 |
+
print(f"Count of total : {count, head * self.train_dataloader.dataset.len}")
|
682 |
+
df_all = df_all.div(count) # head * self.train_dataloader.dataset.len
|
683 |
+
df_all = df_all.reindex(index=labels, columns=labels)
|
684 |
+
sns.heatmap(df_all, annot=False)
|
685 |
+
plt.title("Attentions") #Probabilities
|
686 |
+
plt.xlabel("Steps")
|
687 |
+
plt.ylabel("Steps")
|
688 |
+
plt.grid(True)
|
689 |
+
plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
|
690 |
+
plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores.png", bbox_inches='tight')
|
691 |
+
plt.show()
|
692 |
+
plt.close()
|
693 |
+
|
694 |
+
|
695 |
+
|
696 |
+
|
src/reference_code/test.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, optim
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
5 |
+
import numpy as np
|
6 |
+
from keras.preprocessing.sequence import pad_sequences
|
7 |
+
from transformers import BertTokenizer
|
8 |
+
from transformers import BertForSequenceClassification
|
9 |
+
import random
|
10 |
+
from sklearn.metrics import f1_score
|
11 |
+
from utils import *
|
12 |
+
import os
|
13 |
+
import argparse
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
import warnings
|
18 |
+
warnings.filterwarnings("ignore")
|
19 |
+
|
20 |
+
class ModelWithTemperature(nn.Module):
|
21 |
+
"""
|
22 |
+
A thin decorator, which wraps a model with temperature scaling
|
23 |
+
model (nn.Module):
|
24 |
+
A classification neural network
|
25 |
+
NB: Output of the neural network should be the classification logits,
|
26 |
+
NOT the softmax (or log softmax)!
|
27 |
+
"""
|
28 |
+
def __init__(self, model):
|
29 |
+
super(ModelWithTemperature, self).__init__()
|
30 |
+
self.model = model
|
31 |
+
self.temperature = nn.Parameter(torch.ones(1) * 1.5)
|
32 |
+
|
33 |
+
def forward(self, input_ids, token_type_ids, attention_mask):
|
34 |
+
logits = self.model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
|
35 |
+
return self.temperature_scale(logits)
|
36 |
+
|
37 |
+
def temperature_scale(self, logits):
|
38 |
+
"""
|
39 |
+
Perform temperature scaling on logits
|
40 |
+
"""
|
41 |
+
# Expand temperature to match the size of logits
|
42 |
+
temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
|
43 |
+
return logits / temperature
|
44 |
+
|
45 |
+
# This function probably should live outside of this class, but whatever
|
46 |
+
def set_temperature(self, valid_loader, args):
|
47 |
+
"""
|
48 |
+
Tune the tempearature of the model (using the validation set).
|
49 |
+
We're going to set it to optimize NLL.
|
50 |
+
valid_loader (DataLoader): validation set loader
|
51 |
+
"""
|
52 |
+
nll_criterion = nn.CrossEntropyLoss()
|
53 |
+
ece_criterion = ECE().to(args.device)
|
54 |
+
|
55 |
+
# First: collect all the logits and labels for the validation set
|
56 |
+
logits_list = []
|
57 |
+
labels_list = []
|
58 |
+
with torch.no_grad():
|
59 |
+
for step, batch in enumerate(valid_loader):
|
60 |
+
batch = tuple(t.to(args.device) for t in batch)
|
61 |
+
b_input_ids, b_input_mask, b_labels = batch
|
62 |
+
logits = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)[0]
|
63 |
+
logits_list.append(logits)
|
64 |
+
labels_list.append(b_labels)
|
65 |
+
logits = torch.cat(logits_list)
|
66 |
+
labels = torch.cat(labels_list)
|
67 |
+
|
68 |
+
# Calculate NLL and ECE before temperature scaling
|
69 |
+
before_temperature_nll = nll_criterion(logits, labels).item()
|
70 |
+
before_temperature_ece = ece_criterion(logits, labels).item()
|
71 |
+
print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))
|
72 |
+
|
73 |
+
# Next: optimize the temperature w.r.t. NLL
|
74 |
+
optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)
|
75 |
+
|
76 |
+
def eval():
|
77 |
+
loss = nll_criterion(self.temperature_scale(logits), labels)
|
78 |
+
loss.backward()
|
79 |
+
return loss
|
80 |
+
optimizer.step(eval)
|
81 |
+
|
82 |
+
# Calculate NLL and ECE after temperature scaling
|
83 |
+
after_temperature_nll = nll_criterion(self.temperature_scale(logits), labels).item()
|
84 |
+
after_temperature_ece = ece_criterion(self.temperature_scale(logits), labels).item()
|
85 |
+
print('Optimal temperature: %.3f' % self.temperature.item())
|
86 |
+
print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))
|
87 |
+
|
88 |
+
return self
|
89 |
+
|
90 |
+
class ECE(nn.Module):
|
91 |
+
|
92 |
+
def __init__(self, n_bins=15):
|
93 |
+
"""
|
94 |
+
n_bins (int): number of confidence interval bins
|
95 |
+
"""
|
96 |
+
super(ECE, self).__init__()
|
97 |
+
bin_boundaries = torch.linspace(0, 1, n_bins + 1)
|
98 |
+
self.bin_lowers = bin_boundaries[:-1]
|
99 |
+
self.bin_uppers = bin_boundaries[1:]
|
100 |
+
|
101 |
+
def forward(self, logits, labels):
|
102 |
+
softmaxes = F.softmax(logits, dim=1)
|
103 |
+
confidences, predictions = torch.max(softmaxes, 1)
|
104 |
+
accuracies = predictions.eq(labels)
|
105 |
+
|
106 |
+
ece = torch.zeros(1, device=logits.device)
|
107 |
+
for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
|
108 |
+
# Calculated |confidence - accuracy| in each bin
|
109 |
+
in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
|
110 |
+
prop_in_bin = in_bin.float().mean()
|
111 |
+
if prop_in_bin.item() > 0:
|
112 |
+
accuracy_in_bin = accuracies[in_bin].float().mean()
|
113 |
+
avg_confidence_in_bin = confidences[in_bin].mean()
|
114 |
+
ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
|
115 |
+
|
116 |
+
return ece
|
117 |
+
|
118 |
+
|
119 |
+
class ECE_v2(nn.Module):
|
120 |
+
def __init__(self, n_bins=15):
|
121 |
+
"""
|
122 |
+
n_bins (int): number of confidence interval bins
|
123 |
+
"""
|
124 |
+
super(ECE_v2, self).__init__()
|
125 |
+
bin_boundaries = torch.linspace(0, 1, n_bins + 1)
|
126 |
+
self.bin_lowers = bin_boundaries[:-1]
|
127 |
+
self.bin_uppers = bin_boundaries[1:]
|
128 |
+
|
129 |
+
def forward(self, softmaxes, labels):
|
130 |
+
confidences, predictions = torch.max(softmaxes, 1)
|
131 |
+
accuracies = predictions.eq(labels)
|
132 |
+
ece = torch.zeros(1, device=softmaxes.device)
|
133 |
+
|
134 |
+
for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
|
135 |
+
# Calculated |confidence - accuracy| in each bin
|
136 |
+
in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
|
137 |
+
prop_in_bin = in_bin.float().mean()
|
138 |
+
if prop_in_bin.item() > 0:
|
139 |
+
accuracy_in_bin = accuracies[in_bin].float().mean()
|
140 |
+
avg_confidence_in_bin = confidences[in_bin].mean()
|
141 |
+
ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
|
142 |
+
return ece
|
143 |
+
|
144 |
+
def accurate_nb(preds, labels):
|
145 |
+
pred_flat = np.argmax(preds, axis=1).flatten()
|
146 |
+
labels_flat = labels.flatten()
|
147 |
+
return np.sum(pred_flat == labels_flat)
|
148 |
+
|
149 |
+
|
150 |
+
def set_seed(args):
|
151 |
+
random.seed(args.seed)
|
152 |
+
np.random.seed(args.seed)
|
153 |
+
torch.manual_seed(args.seed)
|
154 |
+
|
155 |
+
def apply_dropout(m):
|
156 |
+
if type(m) == nn.Dropout:
|
157 |
+
m.train()
|
158 |
+
|
159 |
+
|
160 |
+
def main():
|
161 |
+
|
162 |
+
parser = argparse.ArgumentParser(description='Test code - measure the detection peformance')
|
163 |
+
parser.add_argument('--eva_iter', default=1, type=int, help='number of passes for mc-dropout when evaluation')
|
164 |
+
parser.add_argument('--model', type=str, choices=['base', 'manifold-smoothing', 'mc-dropout','temperature'], default='base')
|
165 |
+
parser.add_argument('--seed', type=int, default=0, help='random seed for test')
|
166 |
+
parser.add_argument("--epochs", default=10, type=int, help="Number of epochs for training.")
|
167 |
+
parser.add_argument('--index', type=int, default=0, help='random seed you used during training')
|
168 |
+
parser.add_argument('--in_dataset', required=True, help='target dataset: 20news')
|
169 |
+
parser.add_argument('--out_dataset', required=True, help='out-of-dist dataset')
|
170 |
+
parser.add_argument('--eval_batch_size', type=int, default=32)
|
171 |
+
parser.add_argument('--saved_dataset', type=str, default='n')
|
172 |
+
parser.add_argument('--eps_out', default=0.001, type=float, help="Perturbation size of out-of-domain adversarial training")
|
173 |
+
parser.add_argument("--eps_y", default=0.1, type=float, help="Perturbation size of label")
|
174 |
+
parser.add_argument('--eps_in', default=0.0001, type=float, help="Perturbation size of in-domain adversarial training")
|
175 |
+
|
176 |
+
args = parser.parse_args()
|
177 |
+
|
178 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
179 |
+
args.device = device
|
180 |
+
set_seed(args)
|
181 |
+
|
182 |
+
outf = 'test/'+args.model+'-'+str(args.index)
|
183 |
+
if not os.path.isdir(outf):
|
184 |
+
os.makedirs(outf)
|
185 |
+
|
186 |
+
if args.model == 'base':
|
187 |
+
dirname = '{}/BERT-base-{}'.format(args.in_dataset, args.index)
|
188 |
+
pretrained_dir = './model_save/{}'.format(dirname)
|
189 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
190 |
+
model = BertForSequenceClassification.from_pretrained(pretrained_dir)
|
191 |
+
model.to(args.device)
|
192 |
+
print('Load Tekenizer')
|
193 |
+
|
194 |
+
elif args.model == 'mc-dropout':
|
195 |
+
dirname = '{}/BERT-base-{}'.format(args.in_dataset, args.index)
|
196 |
+
pretrained_dir = './model_save/{}'.format(dirname)
|
197 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
198 |
+
model = BertForSequenceClassification.from_pretrained(pretrained_dir)
|
199 |
+
model.to(args.device)
|
200 |
+
|
201 |
+
elif args.model == 'temperature':
|
202 |
+
dirname = '{}/BERT-base-{}'.format(args.in_dataset, args.index)
|
203 |
+
pretrained_dir = './model_save/{}'.format(dirname)
|
204 |
+
orig_model = BertForSequenceClassification.from_pretrained(pretrained_dir)
|
205 |
+
orig_model.to(args.device)
|
206 |
+
model = ModelWithTemperature(orig_model)
|
207 |
+
model.to(args.device)
|
208 |
+
|
209 |
+
elif args.model == 'manifold-smoothing':
|
210 |
+
dirname = '{}/BERT-mf-{}-{}-{}-{}'.format(args.in_dataset, args.index, args.eps_in, args.eps_y, args.eps_out)
|
211 |
+
print(dirname)
|
212 |
+
pretrained_dir = './model_save/{}'.format(dirname)
|
213 |
+
model = BertForSequenceClassification.from_pretrained(pretrained_dir)
|
214 |
+
model.to(args.device)
|
215 |
+
|
216 |
+
|
217 |
+
if args.saved_dataset == 'n':
|
218 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
|
219 |
+
train_sentences, val_sentences, test_sentences, train_labels, val_labels, test_labels = load_dataset(args.in_dataset)
|
220 |
+
_, _, nt_test_sentences, _, _, nt_test_labels = load_dataset(args.out_dataset)
|
221 |
+
|
222 |
+
val_input_ids = []
|
223 |
+
test_input_ids = []
|
224 |
+
nt_test_input_ids = []
|
225 |
+
|
226 |
+
if args.in_dataset == '20news' or args.in_dataset == '20news-15':
|
227 |
+
MAX_LEN = 150
|
228 |
+
else:
|
229 |
+
MAX_LEN = 256
|
230 |
+
|
231 |
+
for sent in val_sentences:
|
232 |
+
encoded_sent = tokenizer.encode(
|
233 |
+
sent, # Sentence to encode.
|
234 |
+
add_special_tokens = True, # Add '[CLS]' and '[SEP]'
|
235 |
+
truncation= True,
|
236 |
+
max_length = MAX_LEN, # Truncate all sentences.
|
237 |
+
#return_tensors = 'pt', # Return pytorch tensors.
|
238 |
+
)
|
239 |
+
# Add the encoded sentence to the list.
|
240 |
+
val_input_ids.append(encoded_sent)
|
241 |
+
|
242 |
+
|
243 |
+
for sent in test_sentences:
|
244 |
+
encoded_sent = tokenizer.encode(
|
245 |
+
sent, # Sentence to encode.
|
246 |
+
add_special_tokens = True, # Add '[CLS]' and '[SEP]'
|
247 |
+
truncation= True,
|
248 |
+
max_length = MAX_LEN, # Truncate all sentences.
|
249 |
+
#return_tensors = 'pt', # Return pytorch tensors.
|
250 |
+
)
|
251 |
+
# Add the encoded sentence to the list.
|
252 |
+
test_input_ids.append(encoded_sent)
|
253 |
+
|
254 |
+
for sent in nt_test_sentences:
|
255 |
+
encoded_sent = tokenizer.encode(
|
256 |
+
sent,
|
257 |
+
add_special_tokens = True,
|
258 |
+
truncation= True,
|
259 |
+
max_length = MAX_LEN,
|
260 |
+
)
|
261 |
+
nt_test_input_ids.append(encoded_sent)
|
262 |
+
|
263 |
+
# Pad our input tokens
|
264 |
+
val_input_ids = pad_sequences(val_input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
|
265 |
+
test_input_ids = pad_sequences(test_input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
|
266 |
+
nt_test_input_ids = pad_sequences(nt_test_input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
|
267 |
+
|
268 |
+
val_attention_masks = []
|
269 |
+
test_attention_masks = []
|
270 |
+
nt_test_attention_masks = []
|
271 |
+
|
272 |
+
for seq in val_input_ids:
|
273 |
+
seq_mask = [float(i>0) for i in seq]
|
274 |
+
val_attention_masks.append(seq_mask)
|
275 |
+
for seq in test_input_ids:
|
276 |
+
seq_mask = [float(i>0) for i in seq]
|
277 |
+
test_attention_masks.append(seq_mask)
|
278 |
+
for seq in nt_test_input_ids:
|
279 |
+
seq_mask = [float(i>0) for i in seq]
|
280 |
+
nt_test_attention_masks.append(seq_mask)
|
281 |
+
|
282 |
+
|
283 |
+
val_inputs = torch.tensor(val_input_ids)
|
284 |
+
val_labels = torch.tensor(val_labels)
|
285 |
+
val_masks = torch.tensor(val_attention_masks)
|
286 |
+
|
287 |
+
test_inputs = torch.tensor(test_input_ids)
|
288 |
+
test_labels = torch.tensor(test_labels)
|
289 |
+
test_masks = torch.tensor(test_attention_masks)
|
290 |
+
|
291 |
+
nt_test_inputs = torch.tensor(nt_test_input_ids)
|
292 |
+
nt_test_labels = torch.tensor(nt_test_labels)
|
293 |
+
nt_test_masks = torch.tensor(nt_test_attention_masks)
|
294 |
+
|
295 |
+
val_data = TensorDataset(val_inputs, val_masks, val_labels)
|
296 |
+
test_data = TensorDataset(test_inputs, test_masks, test_labels)
|
297 |
+
nt_test_data = TensorDataset(nt_test_inputs, nt_test_masks, nt_test_labels)
|
298 |
+
|
299 |
+
dataset_dir = 'dataset/test'
|
300 |
+
if not os.path.exists(dataset_dir):
|
301 |
+
os.makedirs(dataset_dir)
|
302 |
+
torch.save(val_data, dataset_dir+'/{}_val_in_domain.pt'.format(args.in_dataset))
|
303 |
+
torch.save(test_data, dataset_dir+'/{}_test_in_domain.pt'.format(args.in_dataset))
|
304 |
+
torch.save(nt_test_data, dataset_dir+'/{}_test_out_of_domain.pt'.format(args.out_dataset))
|
305 |
+
|
306 |
+
else:
|
307 |
+
dataset_dir = 'dataset/test'
|
308 |
+
val_data = torch.load(dataset_dir+'/{}_val_in_domain.pt'.format(args.in_dataset))
|
309 |
+
test_data = torch.load(dataset_dir+'/{}_test_in_domain.pt'.format(args.in_dataset))
|
310 |
+
nt_test_data = torch.load(dataset_dir+'/{}_test_out_of_domain.pt'.format(args.out_dataset))
|
311 |
+
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
|
316 |
+
######## saved dataset
|
317 |
+
test_sampler = SequentialSampler(test_data)
|
318 |
+
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)
|
319 |
+
|
320 |
+
nt_test_sampler = SequentialSampler(nt_test_data)
|
321 |
+
nt_test_dataloader = DataLoader(nt_test_data, sampler=nt_test_sampler, batch_size=args.eval_batch_size)
|
322 |
+
val_sampler = SequentialSampler(val_data)
|
323 |
+
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=args.eval_batch_size)
|
324 |
+
|
325 |
+
if args.model == 'temperature':
|
326 |
+
model.set_temperature(val_dataloader, args)
|
327 |
+
|
328 |
+
model.eval()
|
329 |
+
|
330 |
+
if args.model == 'mc-dropout':
|
331 |
+
model.apply(apply_dropout)
|
332 |
+
|
333 |
+
correct = 0
|
334 |
+
total = 0
|
335 |
+
output_list = []
|
336 |
+
labels_list = []
|
337 |
+
|
338 |
+
##### validation dat
|
339 |
+
with torch.no_grad():
|
340 |
+
for step, batch in enumerate(val_dataloader):
|
341 |
+
batch = tuple(t.to(args.device) for t in batch)
|
342 |
+
b_input_ids, b_input_mask, b_labels = batch
|
343 |
+
total += b_labels.shape[0]
|
344 |
+
batch_output = 0
|
345 |
+
for j in range(args.eva_iter):
|
346 |
+
if args.model == 'temperature':
|
347 |
+
current_batch = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask) #logits
|
348 |
+
else:
|
349 |
+
current_batch = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask)[0] #logits
|
350 |
+
batch_output = batch_output + F.softmax(current_batch, dim=1)
|
351 |
+
batch_output = batch_output/args.eva_iter
|
352 |
+
output_list.append(batch_output)
|
353 |
+
labels_list.append(b_labels)
|
354 |
+
score, predicted = batch_output.max(1)
|
355 |
+
correct += predicted.eq(b_labels).sum().item()
|
356 |
+
|
357 |
+
###calculate accuracy and ECE
|
358 |
+
val_eval_accuracy = correct/total
|
359 |
+
print("Val Accuracy: {}".format(val_eval_accuracy))
|
360 |
+
ece_criterion = ECE_v2().to(args.device)
|
361 |
+
softmaxes_ece = torch.cat(output_list)
|
362 |
+
labels_ece = torch.cat(labels_list)
|
363 |
+
val_ece = ece_criterion(softmaxes_ece, labels_ece).item()
|
364 |
+
print('ECE on Val data: {}'.format(val_ece))
|
365 |
+
|
366 |
+
#### Test data
|
367 |
+
correct = 0
|
368 |
+
total = 0
|
369 |
+
output_list = []
|
370 |
+
labels_list = []
|
371 |
+
predict_list = []
|
372 |
+
true_list = []
|
373 |
+
true_list_ood = []
|
374 |
+
predict_mis = []
|
375 |
+
predict_in = []
|
376 |
+
score_list = []
|
377 |
+
correct_index_all = []
|
378 |
+
## test on in-distribution test set
|
379 |
+
with torch.no_grad():
|
380 |
+
for step, batch in enumerate(test_dataloader):
|
381 |
+
batch = tuple(t.to(args.device) for t in batch)
|
382 |
+
b_input_ids, b_input_mask, b_labels = batch
|
383 |
+
total += b_labels.shape[0]
|
384 |
+
batch_output = 0
|
385 |
+
for j in range(args.eva_iter):
|
386 |
+
if args.model == 'temperature':
|
387 |
+
current_batch = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask) #logits
|
388 |
+
else:
|
389 |
+
current_batch = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask)[0] #logits
|
390 |
+
batch_output = batch_output + F.softmax(current_batch, dim=1)
|
391 |
+
batch_output = batch_output/args.eva_iter
|
392 |
+
output_list.append(batch_output)
|
393 |
+
labels_list.append(b_labels)
|
394 |
+
score, predicted = batch_output.max(1)
|
395 |
+
|
396 |
+
correct += predicted.eq(b_labels).sum().item()
|
397 |
+
|
398 |
+
correct_index = (predicted == b_labels)
|
399 |
+
correct_index_all.append(correct_index)
|
400 |
+
score_list.append(score)
|
401 |
+
|
402 |
+
###calcutae accuracy
|
403 |
+
eval_accuracy = correct/total
|
404 |
+
print("Test Accuracy: {}".format(eval_accuracy))
|
405 |
+
|
406 |
+
##calculate ece
|
407 |
+
ece_criterion = ECE_v2().to(args.device)
|
408 |
+
softmaxes_ece = torch.cat(output_list)
|
409 |
+
labels_ece = torch.cat(labels_list)
|
410 |
+
ece = ece_criterion(softmaxes_ece, labels_ece).item()
|
411 |
+
print('ECE on Test data: {}'.format(ece))
|
412 |
+
|
413 |
+
#confidence for in-distribution data
|
414 |
+
score_in_array = torch.cat(score_list)
|
415 |
+
#indices of data that are classified correctly
|
416 |
+
correct_array = torch.cat(correct_index_all)
|
417 |
+
label_array = torch.cat(labels_list)
|
418 |
+
|
419 |
+
### test on out-of-distribution data
|
420 |
+
predict_ood = []
|
421 |
+
score_ood_list = []
|
422 |
+
true_list_ood = []
|
423 |
+
with torch.no_grad():
|
424 |
+
for step, batch in enumerate(nt_test_dataloader):
|
425 |
+
batch = tuple(t.to(args.device) for t in batch)
|
426 |
+
b_input_ids, b_input_mask, b_labels = batch
|
427 |
+
batch_output = 0
|
428 |
+
for j in range(args.eva_iter):
|
429 |
+
if args.model == 'temperature':
|
430 |
+
current_batch = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
|
431 |
+
else:
|
432 |
+
current_batch = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)[0]
|
433 |
+
batch_output = batch_output + F.softmax(current_batch, dim=1)
|
434 |
+
batch_output = batch_output/args.eva_iter
|
435 |
+
score_out, _ = batch_output.max(1)
|
436 |
+
|
437 |
+
score_ood_list.append(score_out)
|
438 |
+
|
439 |
+
score_ood_array = torch.cat(score_ood_list)
|
440 |
+
|
441 |
+
|
442 |
+
|
443 |
+
label_array = label_array.cpu().numpy()
|
444 |
+
score_ood_array = score_ood_array.cpu().numpy()
|
445 |
+
score_in_array = score_in_array.cpu().numpy()
|
446 |
+
correct_array = correct_array.cpu().numpy()
|
447 |
+
|
448 |
+
|
449 |
+
|
450 |
+
|
451 |
+
####### calculate NBAUCC for detection task
|
452 |
+
predict_o = np.zeros(len(score_in_array)+len(score_ood_array))
|
453 |
+
true_o = np.ones(len(score_in_array)+len(score_ood_array))
|
454 |
+
true_o[:len(score_in_array)] = 0 ## in-distribution data as false, ood data as positive
|
455 |
+
true_mis = np.ones(len(score_in_array))
|
456 |
+
true_mis[correct_array] = 0 ##true instances as false, misclassified instances as positive
|
457 |
+
predict_mis = np.zeros(len(score_in_array))
|
458 |
+
|
459 |
+
|
460 |
+
|
461 |
+
ood_sum = 0
|
462 |
+
mis_sum = 0
|
463 |
+
|
464 |
+
ood_sum_list = []
|
465 |
+
mis_sum_list = []
|
466 |
+
|
467 |
+
#### upper bound of the threshold tau for NBAUCC
|
468 |
+
stop_points = [0.50, 1.]
|
469 |
+
|
470 |
+
for threshold in np.arange(0., 1.01, 0.02):
|
471 |
+
predict_ood_index1 = (score_in_array < threshold)
|
472 |
+
predict_ood_index2 = (score_ood_array < threshold)
|
473 |
+
predict_ood_index = np.concatenate((predict_ood_index1, predict_ood_index2), axis=0)
|
474 |
+
predict_o[predict_ood_index] = 1
|
475 |
+
predict_mis[score_in_array<threshold] = 1
|
476 |
+
|
477 |
+
ood = f1_score(true_o, predict_o, average='binary') ##### detection f1 score for a specific threshold
|
478 |
+
mis = f1_score(true_mis, predict_mis, average='binary')
|
479 |
+
|
480 |
+
|
481 |
+
ood_sum += ood*0.02
|
482 |
+
mis_sum += mis*0.02
|
483 |
+
|
484 |
+
if threshold in stop_points:
|
485 |
+
ood_sum_list.append(ood_sum)
|
486 |
+
mis_sum_list.append(mis_sum)
|
487 |
+
|
488 |
+
for i in range(len(stop_points)):
|
489 |
+
print('OOD detection, NBAUCC {}: {}'.format(stop_points[i], ood_sum_list[i]/stop_points[i]))
|
490 |
+
print('misclassification detection, NBAUCC {}: {}'.format(stop_points[i], mis_sum_list[i]/stop_points[i]))
|
491 |
+
|
492 |
+
if __name__ == "__main__":
|
493 |
+
main()
|
src/reference_code/utils.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
import pandas as pd
|
5 |
+
from collections import Counter
|
6 |
+
import numpy as np
|
7 |
+
from sklearn.datasets import fetch_20newsgroups
|
8 |
+
from collections import Counter, defaultdict
|
9 |
+
from nltk.corpus import stopwords
|
10 |
+
from sklearn.model_selection import train_test_split
|
11 |
+
import re
|
12 |
+
from sklearn.utils import shuffle
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
def cos_dist(x, y):
|
17 |
+
## cosine distance function
|
18 |
+
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
19 |
+
batch_size = x.size(0)
|
20 |
+
c = torch.clamp(1 - cos(x.view(batch_size, -1), y.view(batch_size, -1)),
|
21 |
+
min=0)
|
22 |
+
return c.mean()
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
def tag_mapping(tags):
|
28 |
+
"""
|
29 |
+
Create a dictionary and a mapping of tags, sorted by frequency.
|
30 |
+
"""
|
31 |
+
#tags = [s[1] for s in dataset]
|
32 |
+
dico = Counter(tags)
|
33 |
+
tag_to_id, id_to_tag = create_mapping(dico)
|
34 |
+
print("Found %i unique named entity tags" % len(dico))
|
35 |
+
return dico, tag_to_id, id_to_tag
|
36 |
+
|
37 |
+
|
38 |
+
def create_mapping(dico):
|
39 |
+
"""
|
40 |
+
Create a mapping (item to ID / ID to item) from a dictionary.
|
41 |
+
Items are ordered by decreasing frequency.
|
42 |
+
"""
|
43 |
+
sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0]))
|
44 |
+
id_to_item = {i: v[0] for i, v in enumerate(sorted_items)}
|
45 |
+
item_to_id = {v: k for k, v in id_to_item.items()}
|
46 |
+
return item_to_id, id_to_item
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
def clean_str(string):
|
52 |
+
"""
|
53 |
+
Tokenization/string cleaning for all datasets except for SST.
|
54 |
+
Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
|
55 |
+
"""
|
56 |
+
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
|
57 |
+
string = re.sub(r"\'s", " \'s", string)
|
58 |
+
string = re.sub(r"\'ve", " \'ve", string)
|
59 |
+
string = re.sub(r"n\'t", " n\'t", string)
|
60 |
+
string = re.sub(r"\'re", " \'re", string)
|
61 |
+
string = re.sub(r"\'d", " \'d", string)
|
62 |
+
string = re.sub(r"\'ll", " \'ll", string)
|
63 |
+
string = re.sub(r",", " , ", string)
|
64 |
+
string = re.sub(r"!", " ! ", string)
|
65 |
+
string = re.sub(r"\(", " \( ", string)
|
66 |
+
string = re.sub(r"\)", " \) ", string)
|
67 |
+
string = re.sub(r"\?", " \? ", string)
|
68 |
+
string = re.sub(r"\s{2,}", " ", string)
|
69 |
+
return string.strip().lower()
|
70 |
+
|
71 |
+
|
72 |
+
def clean_doc(x, word_freq):
|
73 |
+
stop_words = set(stopwords.words('english'))
|
74 |
+
clean_docs = []
|
75 |
+
most_commons = dict(word_freq.most_common(min(len(word_freq), 50000)))
|
76 |
+
for doc_content in x:
|
77 |
+
doc_words = []
|
78 |
+
cleaned = clean_str(doc_content.strip())
|
79 |
+
for word in cleaned.split():
|
80 |
+
if word not in stop_words and word_freq[word] >= 5:
|
81 |
+
if word in most_commons:
|
82 |
+
doc_words.append(word)
|
83 |
+
else:
|
84 |
+
doc_words.append("<UNK>")
|
85 |
+
doc_str = ' '.join(doc_words).strip()
|
86 |
+
clean_docs.append(doc_str)
|
87 |
+
return clean_docs
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
def load_dataset(dataset):
|
92 |
+
|
93 |
+
if dataset == 'sst':
|
94 |
+
df_train = pd.read_csv("./dataset/sst/SST-2/train.tsv", delimiter='\t', header=0)
|
95 |
+
|
96 |
+
df_val = pd.read_csv("./dataset/sst/SST-2/dev.tsv", delimiter='\t', header=0)
|
97 |
+
|
98 |
+
df_test = pd.read_csv("./dataset/sst/SST-2/sst-test.tsv", delimiter='\t', header=None, names=['sentence', 'label'])
|
99 |
+
|
100 |
+
train_sentences = df_train.sentence.values
|
101 |
+
val_sentences = df_val.sentence.values
|
102 |
+
test_sentences = df_test.sentence.values
|
103 |
+
train_labels = df_train.label.values
|
104 |
+
val_labels = df_val.label.values
|
105 |
+
test_labels = df_test.label.values
|
106 |
+
|
107 |
+
|
108 |
+
if dataset == '20news':
|
109 |
+
|
110 |
+
VALIDATION_SPLIT = 0.8
|
111 |
+
newsgroups_train = fetch_20newsgroups('dataset/20news', subset='train', shuffle=True, random_state=0)
|
112 |
+
print(newsgroups_train.target_names)
|
113 |
+
print(len(newsgroups_train.data))
|
114 |
+
|
115 |
+
newsgroups_test = fetch_20newsgroups('dataset/20news', subset='test', shuffle=False)
|
116 |
+
|
117 |
+
print(len(newsgroups_test.data))
|
118 |
+
|
119 |
+
train_len = int(VALIDATION_SPLIT * len(newsgroups_train.data))
|
120 |
+
|
121 |
+
train_sentences = newsgroups_train.data[:train_len]
|
122 |
+
val_sentences = newsgroups_train.data[train_len:]
|
123 |
+
test_sentences = newsgroups_test.data
|
124 |
+
train_labels = newsgroups_train.target[:train_len]
|
125 |
+
val_labels = newsgroups_train.target[train_len:]
|
126 |
+
test_labels = newsgroups_test.target
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
if dataset == '20news-15':
|
131 |
+
VALIDATION_SPLIT = 0.8
|
132 |
+
cats = ['alt.atheism',
|
133 |
+
'comp.graphics',
|
134 |
+
'comp.os.ms-windows.misc',
|
135 |
+
'comp.sys.ibm.pc.hardware',
|
136 |
+
'comp.sys.mac.hardware',
|
137 |
+
'comp.windows.x',
|
138 |
+
'rec.autos',
|
139 |
+
'rec.motorcycles',
|
140 |
+
'rec.sport.baseball',
|
141 |
+
'rec.sport.hockey',
|
142 |
+
'misc.forsale',
|
143 |
+
'sci.crypt',
|
144 |
+
'sci.electronics',
|
145 |
+
'sci.med',
|
146 |
+
'sci.space']
|
147 |
+
newsgroups_train = fetch_20newsgroups('dataset/20news', subset='train', shuffle=True, categories=cats, random_state=0)
|
148 |
+
print(newsgroups_train.target_names)
|
149 |
+
print(len(newsgroups_train.data))
|
150 |
+
|
151 |
+
newsgroups_test = fetch_20newsgroups('dataset/20news', subset='test', shuffle=False, categories=cats)
|
152 |
+
|
153 |
+
print(len(newsgroups_test.data))
|
154 |
+
|
155 |
+
train_len = int(VALIDATION_SPLIT * len(newsgroups_train.data))
|
156 |
+
|
157 |
+
train_sentences = newsgroups_train.data[:train_len]
|
158 |
+
val_sentences = newsgroups_train.data[train_len:]
|
159 |
+
test_sentences = newsgroups_test.data
|
160 |
+
train_labels = newsgroups_train.target[:train_len]
|
161 |
+
val_labels = newsgroups_train.target[train_len:]
|
162 |
+
test_labels = newsgroups_test.target
|
163 |
+
|
164 |
+
|
165 |
+
if dataset == '20news-5':
|
166 |
+
cats = [
|
167 |
+
'soc.religion.christian',
|
168 |
+
'talk.politics.guns',
|
169 |
+
'talk.politics.mideast',
|
170 |
+
'talk.politics.misc',
|
171 |
+
'talk.religion.misc']
|
172 |
+
|
173 |
+
newsgroups_test = fetch_20newsgroups('dataset/20news', subset='test', shuffle=False, categories=cats)
|
174 |
+
print(newsgroups_test.target_names)
|
175 |
+
print(len(newsgroups_test.data))
|
176 |
+
|
177 |
+
train_sentences = None
|
178 |
+
val_sentences = None
|
179 |
+
test_sentences = newsgroups_test.data
|
180 |
+
train_labels = None
|
181 |
+
val_labels = None
|
182 |
+
test_labels = newsgroups_test.target
|
183 |
+
|
184 |
+
if dataset == 'wos':
|
185 |
+
TESTING_SPLIT = 0.6
|
186 |
+
VALIDATION_SPLIT = 0.8
|
187 |
+
file_path = './dataset/WebOfScience/WOS46985/X.txt'
|
188 |
+
with open(file_path, 'r') as read_file:
|
189 |
+
x_temp = read_file.readlines()
|
190 |
+
x_all = []
|
191 |
+
for x in x_temp:
|
192 |
+
x_all.append(str(x))
|
193 |
+
|
194 |
+
print(len(x_all))
|
195 |
+
|
196 |
+
file_path = './dataset/WebOfScience/WOS46985/Y.txt'
|
197 |
+
with open(file_path, 'r') as read_file:
|
198 |
+
y_temp= read_file.readlines()
|
199 |
+
y_all = []
|
200 |
+
for y in y_temp:
|
201 |
+
y_all.append(int(y))
|
202 |
+
print(len(y_all))
|
203 |
+
print(max(y_all), min(y_all))
|
204 |
+
|
205 |
+
|
206 |
+
x_in = []
|
207 |
+
y_in = []
|
208 |
+
for i in range(len(x_all)):
|
209 |
+
x_in.append(x_all[i])
|
210 |
+
y_in.append(y_all[i])
|
211 |
+
|
212 |
+
|
213 |
+
train_val_len = int(TESTING_SPLIT * len(x_in))
|
214 |
+
train_len = int(VALIDATION_SPLIT * train_val_len)
|
215 |
+
|
216 |
+
train_sentences = x_in[:train_len]
|
217 |
+
val_sentences = x_in[train_len:train_val_len]
|
218 |
+
test_sentences = x_in[train_val_len:]
|
219 |
+
|
220 |
+
train_labels = y_in[:train_len]
|
221 |
+
val_labels = y_in[train_len:train_val_len]
|
222 |
+
test_labels = y_in[train_val_len:]
|
223 |
+
|
224 |
+
print(len(train_labels))
|
225 |
+
print(len(val_labels))
|
226 |
+
print(len(test_labels))
|
227 |
+
|
228 |
+
|
229 |
+
if dataset == 'wos-100':
|
230 |
+
TESTING_SPLIT = 0.6
|
231 |
+
VALIDATION_SPLIT = 0.8
|
232 |
+
file_path = './dataset/WebOfScience/WOS46985/X.txt'
|
233 |
+
with open(file_path, 'r') as read_file:
|
234 |
+
x_temp = read_file.readlines()
|
235 |
+
x_all = []
|
236 |
+
for x in x_temp:
|
237 |
+
x_all.append(str(x))
|
238 |
+
|
239 |
+
print(len(x_all))
|
240 |
+
|
241 |
+
file_path = './dataset/WebOfScience/WOS46985/Y.txt'
|
242 |
+
with open(file_path, 'r') as read_file:
|
243 |
+
y_temp= read_file.readlines()
|
244 |
+
y_all = []
|
245 |
+
for y in y_temp:
|
246 |
+
y_all.append(int(y))
|
247 |
+
print(len(y_all))
|
248 |
+
print(max(y_all), min(y_all))
|
249 |
+
|
250 |
+
|
251 |
+
x_in = []
|
252 |
+
y_in = []
|
253 |
+
for i in range(len(x_all)):
|
254 |
+
if y_all[i] in range(100):
|
255 |
+
x_in.append(x_all[i])
|
256 |
+
y_in.append(y_all[i])
|
257 |
+
|
258 |
+
for i in range(133):
|
259 |
+
num = 0
|
260 |
+
for y in y_in:
|
261 |
+
if y == i:
|
262 |
+
num = num + 1
|
263 |
+
# print(num)
|
264 |
+
|
265 |
+
train_val_len = int(TESTING_SPLIT * len(x_in))
|
266 |
+
train_len = int(VALIDATION_SPLIT * train_val_len)
|
267 |
+
|
268 |
+
train_sentences = x_in[:train_len]
|
269 |
+
val_sentences = x_in[train_len:train_val_len]
|
270 |
+
test_sentences = x_in[train_val_len:]
|
271 |
+
|
272 |
+
train_labels = y_in[:train_len]
|
273 |
+
val_labels = y_in[train_len:train_val_len]
|
274 |
+
test_labels = y_in[train_val_len:]
|
275 |
+
|
276 |
+
print(len(train_labels))
|
277 |
+
print(len(val_labels))
|
278 |
+
print(len(test_labels))
|
279 |
+
|
280 |
+
if dataset == 'wos-34':
|
281 |
+
TESTING_SPLIT = 0.6
|
282 |
+
VALIDATION_SPLIT = 0.8
|
283 |
+
file_path = './dataset/WebOfScience/WOS46985/X.txt'
|
284 |
+
with open(file_path, 'r') as read_file:
|
285 |
+
x_temp = read_file.readlines()
|
286 |
+
x_all = []
|
287 |
+
for x in x_temp:
|
288 |
+
x_all.append(str(x))
|
289 |
+
|
290 |
+
print(len(x_all))
|
291 |
+
|
292 |
+
file_path = './dataset/WebOfScience/WOS46985/Y.txt'
|
293 |
+
with open(file_path, 'r') as read_file:
|
294 |
+
y_temp= read_file.readlines()
|
295 |
+
y_all = []
|
296 |
+
for y in y_temp:
|
297 |
+
y_all.append(int(y))
|
298 |
+
print(len(y_all))
|
299 |
+
print(max(y_all), min(y_all))
|
300 |
+
|
301 |
+
x_in = []
|
302 |
+
y_in = []
|
303 |
+
for i in range(len(x_all)):
|
304 |
+
if (y_all[i] in range(100)) != True:
|
305 |
+
x_in.append(x_all[i])
|
306 |
+
y_in.append(y_all[i])
|
307 |
+
|
308 |
+
for i in range(133):
|
309 |
+
num = 0
|
310 |
+
for y in y_in:
|
311 |
+
if y == i:
|
312 |
+
num = num + 1
|
313 |
+
# print(num)
|
314 |
+
|
315 |
+
train_val_len = int(TESTING_SPLIT * len(x_in))
|
316 |
+
train_len = int(VALIDATION_SPLIT * train_val_len)
|
317 |
+
|
318 |
+
train_sentences = None
|
319 |
+
val_sentences = None
|
320 |
+
test_sentences = x_in[train_val_len:]
|
321 |
+
|
322 |
+
train_labels = None
|
323 |
+
val_labels = None
|
324 |
+
test_labels = y_in[train_val_len:]
|
325 |
+
|
326 |
+
print(len(test_labels))
|
327 |
+
|
328 |
+
if dataset == 'agnews':
|
329 |
+
|
330 |
+
VALIDATION_SPLIT = 0.8
|
331 |
+
labels_in_domain = [1, 2]
|
332 |
+
|
333 |
+
train_df = pd.read_csv('./dataset/agnews/train.csv', header=None)
|
334 |
+
train_df.rename(columns={0: 'label',1: 'title', 2:'sentence'}, inplace=True)
|
335 |
+
# train_df = pd.concat([train_df, pd.get_dummies(train_df['label'],prefix='label')], axis=1)
|
336 |
+
print(train_df.dtypes)
|
337 |
+
train_in_df_sentence = []
|
338 |
+
train_in_df_label = []
|
339 |
+
|
340 |
+
for i in range(len(train_df.sentence.values)):
|
341 |
+
sentence_temp = ''.join(str(train_df.sentence.values[i]))
|
342 |
+
train_in_df_sentence.append(sentence_temp)
|
343 |
+
train_in_df_label.append(train_df.label.values[i]-1)
|
344 |
+
|
345 |
+
test_df = pd.read_csv('./dataset/agnews/test.csv', header=None)
|
346 |
+
test_df.rename(columns={0: 'label',1: 'title', 2:'sentence'}, inplace=True)
|
347 |
+
# test_df = pd.concat([test_df, pd.get_dummies(test_df['label'],prefix='label')], axis=1)
|
348 |
+
test_in_df_sentence = []
|
349 |
+
test_in_df_label = []
|
350 |
+
for i in range(len(test_df.sentence.values)):
|
351 |
+
test_in_df_sentence.append(str(test_df.sentence.values[i]))
|
352 |
+
test_in_df_label.append(test_df.label.values[i]-1)
|
353 |
+
|
354 |
+
train_len = int(VALIDATION_SPLIT * len(train_in_df_sentence))
|
355 |
+
|
356 |
+
train_sentences = train_in_df_sentence[:train_len]
|
357 |
+
val_sentences = train_in_df_sentence[train_len:]
|
358 |
+
test_sentences = test_in_df_sentence
|
359 |
+
train_labels = train_in_df_label[:train_len]
|
360 |
+
val_labels = train_in_df_label[train_len:]
|
361 |
+
test_labels = test_in_df_label
|
362 |
+
print(len(train_sentences))
|
363 |
+
print(len(val_sentences))
|
364 |
+
print(len(test_sentences))
|
365 |
+
|
366 |
+
|
367 |
+
return train_sentences, val_sentences, test_sentences, train_labels, val_labels, test_labels
|
368 |
+
|
369 |
+
|
src/reference_code/visualization.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
#import matplotlib as mpl
|
3 |
+
#mpl.use('Agg')
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
import metrics
|
7 |
+
|
8 |
+
class ConfidenceHistogram(metrics.MaxProbCELoss):
|
9 |
+
|
10 |
+
def plot(self, output, labels, n_bins = 15, logits = True, title = None):
|
11 |
+
super().loss(output, labels, n_bins, logits)
|
12 |
+
#scale each datapoint
|
13 |
+
n = len(labels)
|
14 |
+
w = np.ones(n)/n
|
15 |
+
|
16 |
+
plt.rcParams["font.family"] = "serif"
|
17 |
+
#size and axis limits
|
18 |
+
plt.figure(figsize=(3,3))
|
19 |
+
plt.xlim(0,1)
|
20 |
+
plt.ylim(0,1)
|
21 |
+
plt.xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'])
|
22 |
+
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'])
|
23 |
+
#plot grid
|
24 |
+
plt.grid(color='tab:grey', linestyle=(0, (1, 5)), linewidth=1,zorder=0)
|
25 |
+
#plot histogram
|
26 |
+
plt.hist(self.confidences,n_bins,weights = w,color='b',range=(0.0,1.0),edgecolor = 'k')
|
27 |
+
|
28 |
+
#plot vertical dashed lines
|
29 |
+
acc = np.mean(self.accuracies)
|
30 |
+
conf = np.mean(self.confidences)
|
31 |
+
plt.axvline(x=acc, color='tab:grey', linestyle='--', linewidth = 3)
|
32 |
+
plt.axvline(x=conf, color='tab:grey', linestyle='--', linewidth = 3)
|
33 |
+
if acc > conf:
|
34 |
+
plt.text(acc+0.03,0.9,'Accuracy',rotation=90,fontsize=11)
|
35 |
+
plt.text(conf-0.07,0.9,'Avg. Confidence',rotation=90, fontsize=11)
|
36 |
+
else:
|
37 |
+
plt.text(acc-0.07,0.9,'Accuracy',rotation=90,fontsize=11)
|
38 |
+
plt.text(conf+0.03,0.9,'Avg. Confidence',rotation=90, fontsize=11)
|
39 |
+
|
40 |
+
plt.ylabel('% of Samples',fontsize=13)
|
41 |
+
plt.xlabel('Confidence',fontsize=13)
|
42 |
+
plt.tight_layout()
|
43 |
+
if title is not None:
|
44 |
+
plt.title(title,fontsize=16)
|
45 |
+
return plt
|
46 |
+
|
47 |
+
class ReliabilityDiagram(metrics.MaxProbCELoss):
|
48 |
+
|
49 |
+
def plot(self, output, labels, n_bins = 15, logits = True, title = None):
|
50 |
+
super().loss(output, labels, n_bins, logits)
|
51 |
+
|
52 |
+
#computations
|
53 |
+
delta = 1.0/n_bins
|
54 |
+
x = np.arange(0,1,delta)
|
55 |
+
mid = np.linspace(delta/2,1-delta/2,n_bins)
|
56 |
+
error = np.abs(np.subtract(mid,self.bin_acc))
|
57 |
+
|
58 |
+
plt.rcParams["font.family"] = "serif"
|
59 |
+
#size and axis limits
|
60 |
+
plt.figure(figsize=(3,3))
|
61 |
+
plt.xlim(0,1)
|
62 |
+
plt.ylim(0,1)
|
63 |
+
#plot grid
|
64 |
+
plt.grid(color='tab:grey', linestyle=(0, (1, 5)), linewidth=1,zorder=0)
|
65 |
+
#plot bars and identity line
|
66 |
+
plt.bar(x, self.bin_acc, color = 'b', width=delta,align='edge',edgecolor = 'k',label='Outputs',zorder=5)
|
67 |
+
plt.bar(x, error, bottom=np.minimum(self.bin_acc,mid), color = 'mistyrose', alpha=0.5, width=delta,align='edge',edgecolor = 'r',hatch='/',label='Gap',zorder=10)
|
68 |
+
ident = [0.0, 1.0]
|
69 |
+
plt.plot(ident,ident,linestyle='--',color='tab:grey',zorder=15)
|
70 |
+
#labels and legend
|
71 |
+
plt.ylabel('Accuracy',fontsize=13)
|
72 |
+
plt.xlabel('Confidence',fontsize=13)
|
73 |
+
plt.legend(loc='upper left',framealpha=1.0,fontsize='medium')
|
74 |
+
if title is not None:
|
75 |
+
plt.title(title,fontsize=16)
|
76 |
+
plt.tight_layout()
|
77 |
+
|
78 |
+
return plt
|
src/seq_model.py
CHANGED
@@ -1,6 +1,10 @@
|
|
1 |
import torch.nn as nn
|
2 |
|
|
|
|
|
|
|
3 |
from bert import BERT
|
|
|
4 |
|
5 |
|
6 |
class BERTSM(nn.Module):
|
@@ -18,6 +22,12 @@ class BERTSM(nn.Module):
|
|
18 |
super().__init__()
|
19 |
self.bert = bert
|
20 |
self.mask_lm = MaskedSequenceModel(self.bert.hidden, vocab_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
self.same_student = SameStudentPrediction(self.bert.hidden)
|
22 |
|
23 |
def forward(self, x, segment_label, pred=False):
|
@@ -28,6 +38,7 @@ class BERTSM(nn.Module):
|
|
28 |
return x[:, 0], self.mask_lm(x), self.same_student(x)
|
29 |
else:
|
30 |
return x[:, 0], self.mask_lm(x)
|
|
|
31 |
|
32 |
|
33 |
class MaskedSequenceModel(nn.Module):
|
@@ -46,6 +57,9 @@ class MaskedSequenceModel(nn.Module):
|
|
46 |
self.softmax = nn.LogSoftmax(dim=-1)
|
47 |
|
48 |
def forward(self, x):
|
|
|
|
|
|
|
49 |
return self.softmax(self.linear(x))
|
50 |
|
51 |
|
@@ -62,3 +76,4 @@ class SameStudentPrediction(nn.Module):
|
|
62 |
def forward(self, x):
|
63 |
return self.softmax(self.linear(x[:, 0]))
|
64 |
|
|
|
|
1 |
import torch.nn as nn
|
2 |
|
3 |
+
<<<<<<< HEAD
|
4 |
+
from .bert import BERT
|
5 |
+
=======
|
6 |
from bert import BERT
|
7 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
8 |
|
9 |
|
10 |
class BERTSM(nn.Module):
|
|
|
22 |
super().__init__()
|
23 |
self.bert = bert
|
24 |
self.mask_lm = MaskedSequenceModel(self.bert.hidden, vocab_size)
|
25 |
+
<<<<<<< HEAD
|
26 |
+
|
27 |
+
def forward(self, x, segment_label):
|
28 |
+
x = self.bert(x, segment_label)
|
29 |
+
return self.mask_lm(x), x[:, 0]
|
30 |
+
=======
|
31 |
self.same_student = SameStudentPrediction(self.bert.hidden)
|
32 |
|
33 |
def forward(self, x, segment_label, pred=False):
|
|
|
38 |
return x[:, 0], self.mask_lm(x), self.same_student(x)
|
39 |
else:
|
40 |
return x[:, 0], self.mask_lm(x)
|
41 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
42 |
|
43 |
|
44 |
class MaskedSequenceModel(nn.Module):
|
|
|
57 |
self.softmax = nn.LogSoftmax(dim=-1)
|
58 |
|
59 |
def forward(self, x):
|
60 |
+
<<<<<<< HEAD
|
61 |
+
return self.softmax(self.linear(x))
|
62 |
+
=======
|
63 |
return self.softmax(self.linear(x))
|
64 |
|
65 |
|
|
|
76 |
def forward(self, x):
|
77 |
return self.softmax(self.linear(x[:, 0]))
|
78 |
|
79 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
src/transformer.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1 |
import torch.nn as nn
|
2 |
|
|
|
|
|
|
|
|
|
3 |
from attention import MultiHeadedAttention
|
4 |
from transformer_component import SublayerConnection, PositionwiseFeedForward
|
|
|
5 |
|
6 |
class TransformerBlock(nn.Module):
|
7 |
"""
|
@@ -25,6 +30,12 @@ class TransformerBlock(nn.Module):
|
|
25 |
self.dropout = nn.Dropout(p=dropout)
|
26 |
|
27 |
def forward(self, x, mask):
|
|
|
|
|
|
|
|
|
|
|
28 |
x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
|
|
|
29 |
x = self.output_sublayer(x, self.feed_forward)
|
30 |
return self.dropout(x)
|
|
|
1 |
import torch.nn as nn
|
2 |
|
3 |
+
<<<<<<< HEAD
|
4 |
+
from .attention import MultiHeadedAttention
|
5 |
+
from .transformer_component import SublayerConnection, PositionwiseFeedForward
|
6 |
+
=======
|
7 |
from attention import MultiHeadedAttention
|
8 |
from transformer_component import SublayerConnection, PositionwiseFeedForward
|
9 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
10 |
|
11 |
class TransformerBlock(nn.Module):
|
12 |
"""
|
|
|
30 |
self.dropout = nn.Dropout(p=dropout)
|
31 |
|
32 |
def forward(self, x, mask):
|
33 |
+
<<<<<<< HEAD
|
34 |
+
attn_output, p_attn = self.attention.forward(x, x, x, mask=mask)
|
35 |
+
self.p_attn = p_attn.cpu().detach().numpy()
|
36 |
+
x = self.input_sublayer(x, lambda _x: attn_output)
|
37 |
+
=======
|
38 |
x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
|
39 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
40 |
x = self.output_sublayer(x, self.feed_forward)
|
41 |
return self.dropout(x)
|
src/vocab.py
CHANGED
@@ -1,9 +1,22 @@
|
|
1 |
import collections
|
2 |
import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
class Vocab(object):
|
5 |
"""
|
6 |
Special tokens predefined in the vocab file are:
|
|
|
|
|
|
|
|
|
7 |
-[UNK]
|
8 |
-[MASK]
|
9 |
-[CLS]
|
@@ -35,7 +48,11 @@ class Vocab(object):
|
|
35 |
words = [self.invocab[index] if index < len(self.invocab)
|
36 |
else "[%d]" % index for index in seq ]
|
37 |
|
|
|
|
|
|
|
38 |
return " ".join(words)
|
|
|
39 |
|
40 |
|
41 |
# if __init__ == "__main__":
|
|
|
1 |
import collections
|
2 |
import tqdm
|
3 |
+
<<<<<<< HEAD
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
head_directory = Path(__file__).resolve().parent.parent
|
8 |
+
# print(head_directory)
|
9 |
+
os.chdir(head_directory)
|
10 |
+
=======
|
11 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
12 |
|
13 |
class Vocab(object):
|
14 |
"""
|
15 |
Special tokens predefined in the vocab file are:
|
16 |
+
<<<<<<< HEAD
|
17 |
+
-[PAD]
|
18 |
+
=======
|
19 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
20 |
-[UNK]
|
21 |
-[MASK]
|
22 |
-[CLS]
|
|
|
48 |
words = [self.invocab[index] if index < len(self.invocab)
|
49 |
else "[%d]" % index for index in seq ]
|
50 |
|
51 |
+
<<<<<<< HEAD
|
52 |
+
return words #" ".join(words)
|
53 |
+
=======
|
54 |
return " ".join(words)
|
55 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
56 |
|
57 |
|
58 |
# if __init__ == "__main__":
|
test.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
subprocess.run([
|
3 |
+
"python", "new_test_saved_finetuned_model.py",
|
4 |
+
"-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
|
5 |
+
"-finetune_task", "highGRschool10",
|
6 |
+
"-finetuned_bert_classifier_checkpoint",
|
7 |
+
"ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42"
|
8 |
+
])
|
test.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
test_hint_fine_tuned.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from src.vocab import Vocab
|
4 |
+
from src.dataset import TokenizerDataset
|
5 |
+
from hint_fine_tuning import CustomBERTModel
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
def test_model(opt):
|
9 |
+
print(f"Loading Vocab {opt.vocab_path}")
|
10 |
+
vocab = Vocab(opt.vocab_path)
|
11 |
+
vocab.load_vocab()
|
12 |
+
|
13 |
+
print(f"Vocab Size: {len(vocab.vocab)}")
|
14 |
+
|
15 |
+
test_dataset = TokenizerDataset(opt.test_dataset, opt.test_label, vocab, seq_len=50) # Using sequence length 50
|
16 |
+
print(f"Creating Dataloader")
|
17 |
+
test_data_loader = DataLoader(test_dataset, batch_size=32, num_workers=4)
|
18 |
+
|
19 |
+
# Load the entire fine-tuned model (including both architecture and weights)
|
20 |
+
print(f"Loading Model from {opt.finetuned_bert_checkpoint}")
|
21 |
+
model = torch.load(opt.finetuned_bert_checkpoint, map_location="cpu")
|
22 |
+
|
23 |
+
print(f"Number of Labels: {opt.num_labels}")
|
24 |
+
|
25 |
+
model.eval()
|
26 |
+
for batch_idx, data in enumerate(test_data_loader):
|
27 |
+
inputs = data["input"].to("cpu")
|
28 |
+
segment_info = data["segment_label"].to("cpu")
|
29 |
+
|
30 |
+
with torch.no_grad():
|
31 |
+
logits = model(inputs, segment_info)
|
32 |
+
|
33 |
+
print(f"Batch {batch_idx} logits: {logits}")
|
34 |
+
|
35 |
+
if __name__ == "__main__":
|
36 |
+
parser = argparse.ArgumentParser()
|
37 |
+
|
38 |
+
parser.add_argument("-t", "--test_dataset", type=str, default="/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_test_dataset.csv", help="test set for evaluating fine-tuned model")
|
39 |
+
parser.add_argument("-tlabel", "--test_label", type=str, default="/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/test_infos_only.csv", help="label set for evaluating fine-tuned model")
|
40 |
+
parser.add_argument("-c", "--finetuned_bert_checkpoint", type=str, default="/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification/fine_tuned_model_2.pth", help="checkpoint of the saved fine-tuned BERT model")
|
41 |
+
parser.add_argument("-v", "--vocab_path", type=str, default="/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt", help="built vocab model path")
|
42 |
+
parser.add_argument("-num_labels", type=int, default=2, help="Number of labels")
|
43 |
+
|
44 |
+
opt = parser.parse_args()
|
45 |
+
test_model(opt)
|
test_saved_model.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import torch.nn as nn
|
2 |
+
# import torch
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.optim import Adam, SGD
|
9 |
+
import torch
|
10 |
+
from sklearn.metrics import precision_score, recall_score, f1_score
|
11 |
+
|
12 |
+
from src.pretrainer import BERTFineTuneTrainer1
|
13 |
+
from src.dataset import TokenizerDataset
|
14 |
+
from src.vocab import Vocab
|
15 |
+
|
16 |
+
import tqdm
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
import time
|
20 |
+
from src.bert import BERT
|
21 |
+
from hint_fine_tuning import CustomBERTModel
|
22 |
+
|
23 |
+
# from vocab import Vocab
|
24 |
+
|
25 |
+
# class BERTForSequenceClassification(nn.Module):
|
26 |
+
# """
|
27 |
+
# Since its classification,
|
28 |
+
# n_labels = 2
|
29 |
+
# """
|
30 |
+
|
31 |
+
# def __init__(self, vocab_size, n_labels, layers=None, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
|
32 |
+
# super().__init__()
|
33 |
+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
34 |
+
# print(device)
|
35 |
+
# # model_ep0 = torch.load("output_1/bert_trained.model.ep0", map_location=device)
|
36 |
+
# self.bert = torch.load("output_1/bert_trained.model.ep0", map_location=device)
|
37 |
+
# self.dropout = nn.Dropout(dropout)
|
38 |
+
# # add an output layer
|
39 |
+
# self.
|
40 |
+
|
41 |
+
# def forward(self, x, segment_info):
|
42 |
+
|
43 |
+
|
44 |
+
# return x
|
45 |
+
|
46 |
+
|
47 |
+
class BERTFineTunedTrainer:
|
48 |
+
|
49 |
+
def __init__(self, bert: CustomBERTModel, vocab_size: int,
|
50 |
+
train_dataloader: DataLoader = None, test_dataloader: DataLoader = None,
|
51 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
52 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None, num_labels=2):
|
53 |
+
"""
|
54 |
+
:param bert: BERT model which you want to train
|
55 |
+
:param vocab_size: total word vocab size
|
56 |
+
:param train_dataloader: train dataset data loader
|
57 |
+
:param test_dataloader: test dataset data loader [can be None]
|
58 |
+
:param lr: learning rate of optimizer
|
59 |
+
:param betas: Adam optimizer betas
|
60 |
+
:param weight_decay: Adam optimizer weight decay param
|
61 |
+
:param with_cuda: traning with cuda
|
62 |
+
:param log_freq: logging frequency of the batch iteration
|
63 |
+
"""
|
64 |
+
self.device = "cpu"
|
65 |
+
self.model = bert
|
66 |
+
self.test_data = test_dataloader
|
67 |
+
|
68 |
+
self.log_freq = log_freq
|
69 |
+
self.workspace_name = workspace_name
|
70 |
+
# print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
71 |
+
|
72 |
+
def test(self, epoch):
|
73 |
+
self.iteration(epoch, self.test_data, train=False)
|
74 |
+
|
75 |
+
def iteration(self, epoch, data_loader, train=True):
|
76 |
+
"""
|
77 |
+
loop over the data_loader for training or testing
|
78 |
+
if on train status, backward operation is activated
|
79 |
+
and also auto save the model every peoch
|
80 |
+
|
81 |
+
:param epoch: current epoch index
|
82 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
83 |
+
:param train: boolean value of is train or test
|
84 |
+
:return: None
|
85 |
+
"""
|
86 |
+
str_code = "train" if train else "test"
|
87 |
+
|
88 |
+
# Setting the tqdm progress bar
|
89 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
90 |
+
desc="EP_%s:%d" % (str_code, epoch),
|
91 |
+
total=len(data_loader),
|
92 |
+
bar_format="{l_bar}{r_bar}")
|
93 |
+
|
94 |
+
avg_loss = 0.0
|
95 |
+
total_correct = 0
|
96 |
+
total_element = 0
|
97 |
+
|
98 |
+
plabels = []
|
99 |
+
tlabels = []
|
100 |
+
logits_list = []
|
101 |
+
labels_list = []
|
102 |
+
positive_class_probs = []
|
103 |
+
self.model.eval()
|
104 |
+
|
105 |
+
for i, data in data_iter:
|
106 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
107 |
+
|
108 |
+
with torch.no_grad():
|
109 |
+
h_rep, logits = self.model.forward(data["input"], data["segment_label"])
|
110 |
+
# print(logits, logits.shape)
|
111 |
+
logits_list.append(logits.cpu())
|
112 |
+
labels_list.append(data["label"].cpu())
|
113 |
+
|
114 |
+
probs = F.Softmax(dim=-1)(logits)
|
115 |
+
predicted_labels = torch.argmax(probs, dim=-1)
|
116 |
+
true_labels = torch.argmax(data["label"], dim=-1)
|
117 |
+
positive_class_probs.extend(probs[:, 1])
|
118 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
119 |
+
tlabels.extend(true_labels.cpu().numpy())
|
120 |
+
|
121 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels)
|
122 |
+
# Compare predicted labels to true labels and calculate accuracy
|
123 |
+
correct = (predicted_labels == true_labels).sum().item()
|
124 |
+
total_correct += correct
|
125 |
+
total_element += data["label"].nelement()
|
126 |
+
|
127 |
+
precisions = precision_score(tlabels, plabels, average="weighted")
|
128 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
129 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
130 |
+
accuracy = total_correct * 100.0 / total_element
|
131 |
+
auc_score = roc_auc_score(tlabels.cpu(), plabels.cpu())
|
132 |
+
|
133 |
+
final_msg = {
|
134 |
+
"epoch": f"EP{epoch}_{str_code}",
|
135 |
+
"accuracy": accuracy,
|
136 |
+
"avg_loss": avg_loss / len(data_iter),
|
137 |
+
"precisions": precisions,
|
138 |
+
"recalls": recalls,
|
139 |
+
"f1_scores": f1_scores
|
140 |
+
}
|
141 |
+
|
142 |
+
print(final_msg)
|
143 |
+
|
144 |
+
# print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=", total_correct * 100.0 / total_element)
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
149 |
+
# print(device)
|
150 |
+
# is_model = torch.load("ratio_proportion_change4/output/bert_fine_tuned.IS.model.ep40", map_location=device)
|
151 |
+
# learned_parameters = model_ep0.state_dict()
|
152 |
+
|
153 |
+
# for param_name, param_tensor in learned_parameters.items():
|
154 |
+
# print(param_name)
|
155 |
+
# print(param_tensor)
|
156 |
+
# # print(model_ep0.state_dict())
|
157 |
+
# # model_ep0.add_module("out", nn.Linear(10,2))
|
158 |
+
# # print(model_ep0)
|
159 |
+
# seq_vocab = Vocab("pretraining/vocab_file.txt")
|
160 |
+
# seq_vocab.load_vocab()
|
161 |
+
# classifier = BERTForSequenceClassification(len(seq_vocab.vocab), 2)
|
162 |
+
|
163 |
+
|
164 |
+
parser = argparse.ArgumentParser()
|
165 |
+
|
166 |
+
parser.add_argument('-workspace_name', type=str, default="ratio_proportion_change3_1920")
|
167 |
+
# parser.add_argument("-t", "--test_dataset", type=str, default="finetuning/before_June/train_in.txt", help="test set for evaluate fine tune train set")
|
168 |
+
# parser.add_argument("-tlabel", "--test_label", type=str, default="finetuning/before_June/train_in_label.txt", help="test set for evaluate fine tune train set")
|
169 |
+
# ##### change Checkpoint
|
170 |
+
# parser.add_argument("-c", "--finetuned_bert_checkpoint", type=str, default="ratio_proportion_change3/output/before_June/bert_fine_tuned.FS.model.ep30", help="checkpoint of saved pretrained bert model")
|
171 |
+
# parser.add_argument("-v", "--vocab_path", type=str, default="pretraining/vocab.txt", help="built vocab model path with bert-vocab")
|
172 |
+
parser.add_argument("-t", "--test_dataset", type=str, default="/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_test_dataset.csv", help="test set for evaluate fine tune train set")
|
173 |
+
parser.add_argument("-tlabel", "--test_label", type=str, default="/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/test_infos_only.csv", help="test set for evaluate fine tune train set")
|
174 |
+
##### change Checkpoint
|
175 |
+
parser.add_argument("-c", "--finetuned_bert_checkpoint", type=str, default="/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification/fine_tuned_model_2.pth", help="checkpoint of saved pretrained bert model")
|
176 |
+
parser.add_argument("-v", "--vocab_path", type=str, default="/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt", help="built vocab model path with bert-vocab")
|
177 |
+
parser.add_argument("-num_labels", type=int, default=2, help="Number of labels")
|
178 |
+
|
179 |
+
parser.add_argument("-hs", "--hidden", type=int, default=64, help="hidden size of transformer model")
|
180 |
+
parser.add_argument("-l", "--layers", type=int, default=4, help="number of layers")
|
181 |
+
parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
|
182 |
+
parser.add_argument("-s", "--seq_len", type=int, default=100, help="maximum sequence length")
|
183 |
+
|
184 |
+
parser.add_argument("-b", "--batch_size", type=int, default=32, help="number of batch_size")
|
185 |
+
parser.add_argument("-e", "--epochs", type=int, default=1, help="number of epochs")
|
186 |
+
# Use 50 for pretrain, and 10 for fine tune
|
187 |
+
parser.add_argument("-w", "--num_workers", type=int, default=4, help="dataloader worker size")
|
188 |
+
|
189 |
+
# Later run with cuda
|
190 |
+
parser.add_argument("--with_cuda", type=bool, default=False, help="training with CUDA: true, or false")
|
191 |
+
parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
|
192 |
+
parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
|
193 |
+
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
|
194 |
+
parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")
|
195 |
+
|
196 |
+
parser.add_argument("--dropout", type=float, default=0.1, help="dropout of network")
|
197 |
+
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam")
|
198 |
+
parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
|
199 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
|
200 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value")
|
201 |
+
|
202 |
+
args = parser.parse_args()
|
203 |
+
for k,v in vars(args).items():
|
204 |
+
if ('dataset' in k) or ('path' in k) or ('label' in k):
|
205 |
+
if v:
|
206 |
+
# setattr(args, f"{k}", args.workspace_name+"/"+v)
|
207 |
+
print(f"args.{k} : {getattr(args, f'{k}')}")
|
208 |
+
|
209 |
+
print("Loading Vocab", args.vocab_path)
|
210 |
+
vocab_obj = Vocab(args.vocab_path)
|
211 |
+
vocab_obj.load_vocab()
|
212 |
+
print("Vocab Size: ", len(vocab_obj.vocab))
|
213 |
+
print("Loading Test Dataset", args.test_dataset)
|
214 |
+
test_dataset = TokenizerDataset(args.test_dataset, args.test_label, vocab_obj, seq_len=args.seq_len)
|
215 |
+
print("Creating Dataloader")
|
216 |
+
test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
|
217 |
+
bert = torch.load(args.finetuned_bert_checkpoint, map_location="cpu")
|
218 |
+
num_labels = 2
|
219 |
+
print(f"Number of Labels : {num_labels}")
|
220 |
+
print("Creating BERT Fine Tune Trainer")
|
221 |
+
trainer = BERTFineTuneTrainer1(bert, len(vocab_obj.vocab), train_dataloader=None, test_dataloader=test_data_loader,
|
222 |
+
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq, workspace_name = args.workspace_name, num_labels=args.num_labels)
|
223 |
+
|
224 |
+
print("Testing Start....")
|
225 |
+
start_time = time.time()
|
226 |
+
for epoch in range(args.epochs):
|
227 |
+
trainer.test(epoch)
|
228 |
+
|
229 |
+
end_time = time.time()
|
230 |
+
|
231 |
+
print("Time Taken to fine tune dataset = ", end_time - start_time)
|
232 |
+
|
233 |
+
|
234 |
+
# bert/ratio_proportion_change3_2223/sch_largest_100-coded/output/Opts/bert_fine_tuned.model.ep22
|
visualization.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
#import matplotlib as mpl
|
3 |
+
#mpl.use('Agg')
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
import metrics
|
7 |
+
|
8 |
+
class ConfidenceHistogram(metrics.MaxProbCELoss):
|
9 |
+
|
10 |
+
def plot(self, output, labels, n_bins = 15, logits = True, title = None):
|
11 |
+
super().loss(output, labels, n_bins, logits)
|
12 |
+
#scale each datapoint
|
13 |
+
n = len(labels)
|
14 |
+
w = np.ones(n)/n
|
15 |
+
|
16 |
+
plt.rcParams["font.family"] = "serif"
|
17 |
+
#size and axis limits
|
18 |
+
plt.figure(figsize=(4,3))
|
19 |
+
plt.xlim(0,1)
|
20 |
+
plt.ylim(0,1)
|
21 |
+
plt.xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'])
|
22 |
+
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'])
|
23 |
+
#plot grid
|
24 |
+
plt.grid(color='tab:grey', linestyle=(0, (1, 5)), linewidth=1,zorder=0)
|
25 |
+
#plot histogram
|
26 |
+
plt.hist(self.confidences,n_bins,weights = w,color='b',range=(0.0,1.0),edgecolor = 'k')
|
27 |
+
|
28 |
+
#plot vertical dashed lines
|
29 |
+
acc = np.mean(self.accuracies)
|
30 |
+
conf = np.mean(self.confidences)
|
31 |
+
plt.axvline(x=acc, color='tab:grey', linestyle='--', linewidth = 3)
|
32 |
+
plt.axvline(x=conf, color='tab:grey', linestyle='--', linewidth = 3)
|
33 |
+
if acc > conf:
|
34 |
+
plt.text(acc+0.03,0.4,'Accuracy',rotation=90,fontsize=11)
|
35 |
+
plt.text(conf-0.07,0.4,'Avg. Confidence',rotation=90, fontsize=11)
|
36 |
+
else:
|
37 |
+
plt.text(acc-0.07,0.4,'Accuracy',rotation=90,fontsize=11)
|
38 |
+
plt.text(conf+0.03,0.4,'Avg. Confidence',rotation=90, fontsize=11)
|
39 |
+
|
40 |
+
plt.ylabel('% of Samples',fontsize=13)
|
41 |
+
plt.xlabel('Confidence',fontsize=13)
|
42 |
+
plt.tight_layout()
|
43 |
+
if title is not None:
|
44 |
+
plt.title(title,fontsize=16)
|
45 |
+
return plt
|
46 |
+
|
47 |
+
class ReliabilityDiagram(metrics.MaxProbCELoss):
|
48 |
+
|
49 |
+
def plot(self, output, labels, n_bins = 15, logits = True, title = None):
|
50 |
+
super().loss(output, labels, n_bins, logits)
|
51 |
+
|
52 |
+
#computations
|
53 |
+
delta = 1.0/n_bins
|
54 |
+
x = np.arange(0,1,delta)
|
55 |
+
mid = np.linspace(delta/2,1-delta/2,n_bins)
|
56 |
+
error = np.concatenate((np.zeros(shape=7), np.abs(np.subtract(mid[7:],self.bin_acc[7:]))))
|
57 |
+
|
58 |
+
plt.rcParams["font.family"] = "serif"
|
59 |
+
#size and axis limits
|
60 |
+
plt.figure(figsize=(4,4))
|
61 |
+
plt.xlim(0,1)
|
62 |
+
plt.ylim(0,1)
|
63 |
+
#plot grid
|
64 |
+
plt.grid(color='tab:grey', linestyle=(0, (1, 5)), linewidth=1,zorder=0)
|
65 |
+
#plot bars and identity line
|
66 |
+
plt.bar(x, self.bin_acc, color = 'b', width=delta,align='edge',edgecolor = 'k',label='Outputs',zorder=5)
|
67 |
+
plt.bar(x, error, bottom=np.minimum(self.bin_acc,mid), color = 'mistyrose', alpha=0.5, width=delta,align='edge',edgecolor = 'r',hatch='/',label='Gap',zorder=10)
|
68 |
+
ident = [0.0, 1.0]
|
69 |
+
plt.plot(ident,ident,linestyle='--',color='tab:grey',zorder=15)
|
70 |
+
#labels and legend
|
71 |
+
plt.ylabel('Accuracy',fontsize=13)
|
72 |
+
plt.xlabel('Confidence',fontsize=13)
|
73 |
+
plt.legend(loc='upper left',framealpha=1.0,fontsize='medium')
|
74 |
+
if title is not None:
|
75 |
+
plt.title(title,fontsize=16)
|
76 |
+
plt.tight_layout()
|
77 |
+
|
78 |
+
return plt
|