import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader
import pickle

from bert import BERT
from seq_model import BERTSM
from classifier_model import BERTForClassification
from optim_schedule import ScheduledOptim

import tqdm
import sys

import numpy as np
import visualization

from sklearn.metrics import precision_score, recall_score, f1_score


class ECE(nn.Module):

    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECE, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]
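
    # forward bins predictions by their softmax confidence and accumulates
    # |avg. confidence - accuracy| per bin, weighted by the fraction of samples
    # falling in the bin; the sum is the expected calibration error.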
    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        labels = torch.argmax(labels, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece
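
# Example (sketch): given logits of shape (N, C) and one-hot labels of shape (N, C),
#   ece_value = ECE(n_bins=15)(logits, one_hot_labels)
# returns a one-element tensor holding the expected calibration error.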


def accurate_nb(preds, labels):
    """Count correct predictions given logits/probabilities and one-hot labels."""
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = np.argmax(labels, axis=1).flatten()
    return np.sum(pred_flat == labels_flat)
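
# Two trainers follow: BERTTrainer pretrains BERT with the masked-LM objective
# (optionally adding same-student prediction), and BERTFineTuneTrainer fine-tunes
# the pretrained encoder for classification.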


class BERTTrainer:
    """
    BERTTrainer pretrains a BERT model on the masked language modeling
    objective (BERT paper, Section 3.3.1, Task #1: Masked LM), optionally
    combined with a same-student prediction task.
    """

    def __init__(self, bert: BERT, vocab_size: int,
                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, same_student_prediction=False,
                 workspace_name=None):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param warmup_steps: warm-up steps for the learning-rate schedule
        :param with_cuda: training with cuda
        :param log_freq: logging frequency of the batch iteration
        """

        # Use the GPU if requested and available, otherwise fall back to CPU.
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        print("Device used = ", self.device)

        self.bert = bert
        # Wrap BERT with the masked-LM (sequence model) head for pretraining.
        self.model = BERTSM(bert, vocab_size).to(self.device)

        # Distribute over multiple GPUs if available.
        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Adam wrapped in a warm-up learning-rate schedule.
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)

        # ignore_index=0 excludes padding / non-masked positions from the MLM loss.
        self.criterion = nn.NLLLoss(ignore_index=0)

        self.log_freq = log_freq
        self.same_student_prediction = same_student_prediction
        self.workspace_name = workspace_name
        self.save_model = False
        self.avg_loss = 10000
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        loop over the data_loader for training or testing
        if on train status, backward operation is activated
        and also auto save the model every epoch

        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean value of is train or test
        :return: None
        """
        str_code = "train" if train else "test"
        code = "masked_prediction" if self.same_student_prediction else "masked"

        self.log_file = f"{self.workspace_name}/logs/{code}/log_{str_code}_pretrained.txt"
        bert_hidden_representations = []
        # Truncate the log file at the first epoch; reset the best loss for testing.
        if epoch == 0:
            f = open(self.log_file, 'w')
            f.close()
        if not train:
            self.avg_loss = 10000

        # Progress bar over the data loader.
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        avg_loss_mask = 0.0
        total_correct_mask = 0
        total_element_mask = 0

        avg_loss_pred = 0.0
        total_correct_pred = 0
        total_element_pred = 0

        avg_loss = 0.0
        with open(self.log_file, 'a') as f:
            # Redirect stdout so that per-iteration logs are written to the log file.
            sys.stdout = f
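
            # Iterate over batches: move each tensor to the device, run the masked-LM
            # forward pass (with the optional same-student head), and keep the hidden
            # representations returned by the model.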
            for i, data in data_iter:
                data = {key: value.to(self.device) for key, value in data.items()}

                if self.same_student_prediction:
                    bert_hidden_rep, mask_lm_output, same_student_output = self.model(
                        data["bert_input"], data["segment_label"], self.same_student_prediction)
                else:
                    bert_hidden_rep, mask_lm_output = self.model(
                        data["bert_input"], data["segment_label"], self.same_student_prediction)

                embeddings = [h for h in bert_hidden_rep.cpu().detach().numpy()]
                bert_hidden_representations.extend(embeddings)
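
                # NLLLoss expects class scores as (batch, vocab, seq_len); the
                # masked-LM head returns (batch, seq_len, vocab), hence the transpose.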
                mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

                if self.same_student_prediction:
                    same_student_loss = self.criterion(same_student_output, data["is_same_student"])
                    loss = same_student_loss + mask_loss
                else:
                    loss = mask_loss

                # Backward pass and parameter update only in training mode.
                if train:
                    self.optim_schedule.zero_grad()
                    loss.backward()
                    self.optim_schedule.step_and_update_lr()

                # Masked-token accuracy: only positions with a non-zero label
                # (i.e. positions that were actually masked) are counted.
                non_zero_mask = (data["bert_label"] != 0).float()
                predictions = torch.argmax(mask_lm_output, dim=-1)
                predicted_masked = predictions * non_zero_mask
                mask_correct = ((data["bert_label"] == predicted_masked) * non_zero_mask).sum().item()

                avg_loss_mask += loss.item()
                total_correct_mask += mask_correct
                total_element_mask += non_zero_mask.sum().item()

                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss_mask / (i + 1),
                    "avg_acc_mask": total_correct_mask / total_element_mask * 100,
                    "loss": loss.item()
                }

                if self.same_student_prediction:
                    correct = same_student_output.argmax(dim=-1).eq(data["is_same_student"]).sum().item()
                    avg_loss_pred += loss.item()
                    total_correct_pred += correct
                    total_element_pred += data["is_same_student"].nelement()

                    post_fix["avg_loss"] = avg_loss_pred / (i + 1)
                    post_fix["avg_acc_pred"] = total_correct_pred / total_element_pred * 100
                    post_fix["loss"] = loss.item()

                avg_loss += loss.item()

                if i % self.log_freq == 0:
                    data_iter.write(str(post_fix))
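
            # After the loop: build an epoch-level summary (average loss, masked-token
            # accuracy, and same-student accuracy when that head is enabled).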
            final_msg = {
                "epoch": f"EP{epoch}_{str_code}",
                "avg_loss": avg_loss / len(data_iter),
                "total_masked_acc": total_correct_mask * 100.0 / total_element_mask
            }
            if self.same_student_prediction:
                final_msg["total_prediction_acc"] = total_correct_pred * 100.0 / total_element_pred

            print(final_msg)

        sys.stdout = sys.__stdout__
        # Flag the model for saving whenever the epoch's average loss improves.
        self.save_model = False
        if self.avg_loss > (avg_loss / len(data_iter)):
            self.save_model = True
            self.avg_loss = (avg_loss / len(data_iter))

    def save(self, epoch, file_path="output/bert_trained.model"):
        """
        Saving the current BERT model on file_path

        :param epoch: current epoch number
        :param file_path: model output path, saved as file_path + ".ep%d" % epoch
        :return: final_output_path
        """
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.bert.cpu(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
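

# Example pretraining loop (a sketch only; the BERT constructor arguments, vocab
# object, and dataloaders below are placeholders, not defined in this file):
#
#   bert = BERT(vocab_size=len(vocab), hidden=256, n_layers=4, attn_heads=4)
#   trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_loader,
#                         test_dataloader=test_loader, workspace_name="workspace")
#   for epoch in range(num_epochs):
#       trainer.train(epoch)
#       if trainer.save_model:
#           trainer.save(epoch)
#       trainer.test(epoch)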


class BERTFineTuneTrainer:
    """
    Fine-tunes a pretrained BERT encoder for classification by attaching a
    classification head (BERTForClassification) and training it end to end.
    """

    def __init__(self, bert: BERT, vocab_size: int,
                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None, num_labels=2):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param with_cuda: training with cuda
        :param log_freq: logging frequency of the batch iteration
        :param num_labels: number of target classes (1 for regression)
        """

        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        print("Device used = ", self.device)

        self.bert = bert
        # Classification head on top of the (pretrained) BERT encoder.
        self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)

        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-9)
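
        # Loss selection by label count: 1 -> regression (MSE), 2 -> two-class
        # cross-entropy, >2 -> multi-label targets with BCE-with-logits.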
        if num_labels == 1:
            self.criterion = nn.MSELoss()
        elif num_labels == 2:
            self.criterion = nn.CrossEntropyLoss()
        elif num_labels > 2:
            self.criterion = nn.BCEWithLogitsLoss()

        self.ece_criterion = ECE().to(self.device)

        self.log_freq = log_freq
        self.workspace_name = workspace_name
        self.save_model = False
        self.avg_loss = 10000
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        loop over the data_loader for training or testing
        if on train status, backward operation is activated
        and also auto save the model every epoch

        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean value of is train or test
        :return: None
        """
        str_code = "train" if train else "test"

        self.log_file = f"{self.workspace_name}/logs/masked/log_{str_code}_FS_finetuned.txt"

        if epoch == 0:
            f = open(self.log_file, 'w')
            f.close()
        if not train:
            self.avg_loss = 10000

        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        avg_loss = 0.0
        total_correct = 0
        total_element = 0
        plabels = []
        tlabels = []
        eval_accurate_nb = 0
        nb_eval_examples = 0
        logits_list = []
        labels_list = []

        # Switch between training and evaluation mode (dropout, etc.).
        if train:
            self.model.train()
        else:
            self.model.eval()

        with open(self.log_file, 'a') as f:
            # Redirect stdout so that per-iteration logs go to the log file.
            sys.stdout = f
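
            # For each batch: forward pass (gradients disabled in eval mode), collect
            # logits and one-hot labels for the ECE computation, then apply the
            # classification loss.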
            for i, data in data_iter:
                data = {key: value.to(self.device) for key, value in data.items()}
                if train:
                    h_rep, logits = self.model(data["bert_input"], data["segment_label"])
                else:
                    with torch.no_grad():
                        h_rep, logits = self.model(data["bert_input"], data["segment_label"])

                logits_list.append(logits.cpu())
                labels_list.append(data["progress_status"].cpu())

                progress_loss = self.criterion(logits, data["progress_status"])
                loss = progress_loss

                # Average in case the loss is not a scalar on multi-GPU setups.
                if torch.cuda.device_count() > 1:
                    loss = loss.mean()

                if train:
                    self.optim.zero_grad()
                    loss.backward()
                    # Clip gradients to stabilize fine-tuning.
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optim.step()

                # Predicted vs. true class indices (labels are one-hot encoded).
                probs = nn.LogSoftmax(dim=-1)(logits)
                predicted_labels = torch.argmax(probs, dim=-1)
                true_labels = torch.argmax(data["progress_status"], dim=-1)
                plabels.extend(predicted_labels.cpu().numpy())
                tlabels.extend(true_labels.cpu().numpy())

                correct = (predicted_labels == true_labels).sum().item()
                avg_loss += loss.item()
                total_correct += correct
                total_element += true_labels.nelement()
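
                # Per-iteration progress: training reports running accuracy from the
                # counters above; evaluation recomputes batch accuracy on CPU via
                # accurate_nb and accumulates it for the epoch summary.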
                if train:
                    post_fix = {
                        "epoch": epoch,
                        "iter": i,
                        "avg_loss": avg_loss / (i + 1),
                        "avg_acc": total_correct / total_element * 100,
                        "loss": loss.item()
                    }
                else:
                    logits = logits.detach().cpu().numpy()
                    label_ids = data["progress_status"].to('cpu').numpy()
                    tmp_eval_nb = accurate_nb(logits, label_ids)

                    eval_accurate_nb += tmp_eval_nb
                    nb_eval_examples += label_ids.shape[0]

                    post_fix = {
                        "epoch": epoch,
                        "iter": i,
                        "avg_loss": avg_loss / (i + 1),
                        "avg_acc": eval_accurate_nb / nb_eval_examples * 100,
                        "loss": loss.item()
                    }

                if i % self.log_freq == 0:
                    data_iter.write(str(post_fix))

            # Epoch-level metrics. sklearn's f1_score expects (y_true, y_pred).
            f1_scores = f1_score(tlabels, plabels, average="weighted")
            if train:
                final_msg = {
                    "epoch": f"EP{epoch}_{str_code}",
                    "avg_loss": avg_loss / len(data_iter),
                    "total_acc": total_correct * 100.0 / total_element,
                    "f1_scores": f1_scores
                }
            else:
                eval_accuracy = eval_accurate_nb / nb_eval_examples

                # Calibration error over the whole evaluation set.
                logits_ece = torch.cat(logits_list)
                labels_ece = torch.cat(labels_list)
                ece = self.ece_criterion(logits_ece, labels_ece).item()
                final_msg = {
                    "epoch": f"EP{epoch}_{str_code}",
                    "eval_accuracy": eval_accuracy,
                    "ece": ece,
                    "avg_loss": avg_loss / len(data_iter),
                    "f1_scores": f1_scores
                }
                # When the previous training epoch flagged an improvement, also save
                # calibration plots for this evaluation pass.
                if self.save_model:
                    conf_hist = visualization.ConfidenceHistogram()
                    plt_test = conf_hist.plot(np.array(logits_ece), np.array(labels_ece),
                                              title=f"Confidence Histogram {epoch}")
                    plt_test.savefig(f"{self.workspace_name}/plots/confidence_histogram/FS/conf_histogram_test_{epoch}.png",
                                     bbox_inches='tight')
                    plt_test.close()

                    rel_diagram = visualization.ReliabilityDiagram()
                    plt_test_2 = rel_diagram.plot(np.array(logits_ece), np.array(labels_ece),
                                                  title=f"Reliability Diagram {epoch}")
                    plt_test_2.savefig(f"{self.workspace_name}/plots/confidence_histogram/FS/rel_diagram_test_{epoch}.png",
                                       bbox_inches='tight')
                    plt_test_2.close()
            print(final_msg)

        sys.stdout = sys.__stdout__
        # Track the best (lowest) average training loss and flag the model for saving.
        if train:
            self.save_model = False
            if self.avg_loss > (avg_loss / len(data_iter)):
                self.save_model = True
                self.avg_loss = (avg_loss / len(data_iter))

    def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
        """
        Saving the current fine-tuned model on file_path

        :param epoch: current epoch number
        :param file_path: model output path, saved as file_path + ".ep%d" % epoch
        :return: final_output_path
        """
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.model.cpu(), output_path)
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
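

# Example fine-tuning loop (a sketch only; the pretrained-model path, vocab size,
# and dataloaders below are placeholders, not defined in this file):
#
#   bert = torch.load("output/bert_trained.model.ep0")
#   finetuner = BERTFineTuneTrainer(bert, vocab_size, train_dataloader=ft_train_loader,
#                                   test_dataloader=ft_test_loader,
#                                   workspace_name="workspace", num_labels=2)
#   for epoch in range(num_epochs):
#       finetuner.train(epoch)
#       if finetuner.save_model:
#           finetuner.save(epoch)
#       finetuner.test(epoch)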