# Source listing metadata: file size 6,132 bytes, revision 6a34fd4.
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import numpy
import pickle
import tqdm
from bert import BERT
from vocab import Vocab
from dataset import TokenizerDataset
import argparse
from itertools import combinations
def generate_subset(s):
    """Enumerate every subset of ``s`` and index them.

    Subsets are produced in order of increasing size (the empty subset
    first, the full collection last); within one size they follow the order
    in which ``itertools.combinations`` emits them, i.e. the iteration order
    of ``s``.  For an unordered input such as a ``set`` that order is
    whatever the set yields when iterated.

    Args:
        s: A sized iterable of items (list, tuple, set, ...).

    Returns:
        A dict mapping a running integer index (starting at 0) to one subset,
        each subset being a ``list`` of items.
    """
    # Snapshot the items once so every subset size sees the same ordering.
    pool = tuple(s)
    # The original special-cased size 1, but flattening 1-tuples and
    # re-wrapping each item in a list yields exactly ``list(combo)``.
    subsets = [
        list(combo)
        for size in range(len(pool) + 1)
        for combo in combinations(pool, size)
    ]
    return dict(enumerate(subsets))
def str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    any non-empty value -- including ``"False"`` -- becomes ``True``.  This
    converter accepts the usual spellings instead.

    Args:
        value: The raw argument string (an actual ``bool`` is passed through).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def collect_embeddings(model, data_loader, device, desc):
    """Run ``model`` over every batch and gather one vector per sample.

    For each batch the model is called with the ``"bert_input"`` and
    ``"segment_label"`` tensors, and the slice ``[:, 0]`` of its output is
    kept (presumably the first-token / [CLS] position -- confirm against the
    BERT implementation).

    Args:
        model: Callable model taking ``(bert_input, segment_label)``.
        data_loader: Iterable of dict batches whose values are tensors.
        device: Torch device the batch tensors are moved to.
        desc: Label shown on the progress bar.

    Returns:
        A list with one NumPy entry per sample, in loader order.
    """
    collected = []
    progress = tqdm.tqdm(data_loader,
                         desc=desc,
                         total=len(data_loader),
                         bar_format="{l_bar}{r_bar}")
    # Inference only: skip autograd bookkeeping entirely.
    with torch.no_grad():
        for batch in progress:
            batch = {key: value.to(device) for key, value in batch.items()}
            hrep = model(batch["bert_input"], batch["segment_label"])
            # Extending a list with an ndarray appends one row per sample.
            collected.extend(hrep[:, 0].cpu().detach().numpy())
    return collected


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-workspace_name', type=str, default=None)
    parser.add_argument("-seq_len", type=int, default=100, help="maximum sequence length")
    # str2bool instead of bool: with type=bool, "-pretrain False" parsed as True.
    parser.add_argument('-pretrain', type=str2bool, default=False)
    parser.add_argument('-masked_pred', type=str2bool, default=False)
    parser.add_argument('-epoch', type=str, default=None)
    options = parser.parse_args()

    folder_path = options.workspace_name+"/" if options.workspace_name else ""

    # NOTE: inputs are read from the "check/" sub-folder of the workspace;
    # earlier revisions read "{folder_path}pretraining/..." (and
    # "finetuning/...") directly.
    vocab_path = f"{folder_path}check/pretraining/vocab.txt"
    print("Loading Vocab", vocab_path)
    vocab_obj = Vocab(vocab_path)
    vocab_obj.load_vocab()
    print("Vocab Size: ", len(vocab_obj.vocab))

    # The checkpoint lives directly under the workspace (not under "check/"),
    # so its path is built before folder_path is extended below.
    if options.masked_pred:
        str_code = "masked_prediction"
        output_name = f"{folder_path}output/bert_trained.seq_model.ep{options.epoch}"
    else:
        str_code = "masked"
        output_name = f"{folder_path}output/bert_trained.seq_encoder.model.ep{options.epoch}"

    folder_path = folder_path+"check/"

    # Pick the split-specific inputs, output file and log labels once, so the
    # extraction below is shared by both modes.
    if options.pretrain:
        data_file = f"{folder_path}pretraining/pretrain.txt"
        label_file = f"{folder_path}pretraining/pretrain_opt.pkl"
        embedding_file_path = f"{folder_path}embeddings/pretrain_embeddings_{str_code}_{options.epoch}.pkl"
        loading_message = "Loading Pretrain Dataset "
        stage_message = "Pretrain-embeddings...."
        progress_label = "pre-train"
    else:
        data_file = f"{folder_path}pretraining/test.txt"
        # NOTE(review): the pretrain labels use ".pkl" but the test labels use
        # ".txt" -- confirm this difference is intentional.
        label_file = f"{folder_path}pretraining/test_opt.txt"
        embedding_file_path = f"{folder_path}embeddings/test_embeddings_{str_code}_{options.epoch}.pkl"
        loading_message = "Loading Validation Dataset "
        stage_message = "Validation-embeddings...."
        progress_label = "validation"

    print(loading_message, data_file)
    dataset = TokenizerDataset(data_file, label_file, vocab_obj, seq_len=options.seq_len)

    print("Creating Dataloader")
    data_loader = DataLoader(dataset, batch_size=32, num_workers=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    print("Load Pre-trained BERT model...")
    print(output_name)
    # SECURITY: torch.load unpickles a whole model object; only load
    # checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which rejects full-model pickles -- confirm against the installed version.
    bert = torch.load(output_name, map_location=device)
    for param in bert.parameters():
        param.requires_grad = False
    # Switch off dropout and similar train-time behaviour so the extracted
    # embeddings are deterministic.
    bert.eval()

    print(stage_message)
    embeddings = collect_embeddings(bert, data_loader, device, progress_label)

    # Context manager guarantees the file is flushed and closed.
    with open(embedding_file_path, "wb") as handle:
        pickle.dump(embeddings, handle)
|