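"""Dump sequence embeddings from a previously trained BERT checkpoint.

The script loads the vocabulary from ``<workspace_name>/check/pretraining/``,
reads either the pretraining split (``-pretrain``) or the test split, runs every
sequence through the frozen BERT encoder, and pickles the first-position hidden
state of each sequence to ``<workspace_name>/check/embeddings/``.

Hypothetical invocation (substitute the actual script name for ``embeddings.py``):

    python embeddings.py -workspace_name my_workspace -epoch 10 -pretrain
"""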
import argparse
import pickle
from itertools import combinations

import torch
import torch.nn as nn
import numpy
import tqdm
from torch.utils.data import DataLoader

from bert import BERT
from vocab import Vocab
from dataset import TokenizerDataset

def generate_subset(s):
    """Return the power set of ``s`` as a dict mapping an index to each subset (as a list)."""
    subsets = []
    for r in range(len(s) + 1):
        # combinations(s, r) yields every r-element subset as a tuple
        subsets.extend(list(sublist) for sublist in combinations(s, r))
    subsets_dict = {i: subset for i, subset in enumerate(subsets)}
    return subsets_dict
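# Hypothetical usage of generate_subset (the helper is defined here but not
# called in this script); the result follows directly from the loop above:
#     generate_subset(["a", "b"]) -> {0: [], 1: ['a'], 2: ['b'], 3: ['a', 'b']}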
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('-workspace_name', type=str, default=None)
    parser.add_argument("-seq_len", type=int, default=100, help="maximum sequence length")
    # argparse's type=bool treats any non-empty string (including "False") as True,
    # so the boolean options are exposed as flags instead.
    parser.add_argument('-pretrain', action='store_true',
                        help="embed the pretraining split (pretrain.txt) instead of the test split (test.txt)")
    parser.add_argument('-masked_pred', action='store_true',
                        help="load the bert_trained.seq_model checkpoint instead of bert_trained.seq_encoder.model")
    parser.add_argument('-epoch', type=str, default=None)

    options = parser.parse_args()

    folder_path = options.workspace_name + "/" if options.workspace_name else ""

    vocab_path = f"{folder_path}check/pretraining/vocab.txt"

    print("Loading Vocab", vocab_path)
    vocab_obj = Vocab(vocab_path)
    vocab_obj.load_vocab()
    print("Vocab Size: ", len(vocab_obj.vocab))

    # Select which checkpoint to load; str_code only tags the output embedding file name.
    if options.masked_pred:
        str_code = "masked_prediction"
        output_name = f"{folder_path}output/bert_trained.seq_model.ep{options.epoch}"
    else:
        str_code = "masked"
        output_name = f"{folder_path}output/bert_trained.seq_encoder.model.ep{options.epoch}"

    folder_path = folder_path + "check/"

    if options.pretrain:
        pretrain_file = f"{folder_path}pretraining/pretrain.txt"
        pretrain_label = f"{folder_path}pretraining/pretrain_opt.pkl"

        embedding_file_path = f"{folder_path}embeddings/pretrain_embeddings_{str_code}_{options.epoch}.pkl"
        print("Loading Pretrain Dataset ", pretrain_file)
        pretrain_dataset = TokenizerDataset(pretrain_file, pretrain_label, vocab_obj, seq_len=options.seq_len)

        print("Creating Dataloader")
        pretrain_data_loader = DataLoader(pretrain_dataset, batch_size=32, num_workers=4)
    else:
        val_file = f"{folder_path}pretraining/test.txt"
        val_label = f"{folder_path}pretraining/test_opt.txt"

        embedding_file_path = f"{folder_path}embeddings/test_embeddings_{str_code}_{options.epoch}.pkl"

        print("Loading Validation Dataset ", val_file)
        val_dataset = TokenizerDataset(val_file, val_label, vocab_obj, seq_len=options.seq_len)

        print("Creating Dataloader")
        val_data_loader = DataLoader(val_dataset, batch_size=32, num_workers=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    print("Load Pre-trained BERT model...")
    print(output_name)
    bert = torch.load(output_name, map_location=device)
    bert.eval()  # disable dropout so the extracted embeddings are deterministic

    # Freeze the encoder; this script only extracts embeddings.
    for param in bert.parameters():
        param.requires_grad = False

    if options.pretrain:
        print("Pretrain-embeddings....")
        data_iter = tqdm.tqdm(enumerate(pretrain_data_loader),
                              desc="pre-train",
                              total=len(pretrain_data_loader),
                              bar_format="{l_bar}{r_bar}")
        pretrain_embeddings = []
        for i, data in data_iter:
            data = {key: value.to(device) for key, value in data.items()}
            hrep = bert(data["bert_input"], data["segment_label"])

            # hrep is (batch, seq_len, hidden); keep position 0 of each sequence as its embedding.
            embeddings = [h for h in hrep[:, 0].cpu().detach().numpy()]
            pretrain_embeddings.extend(embeddings)

        with open(embedding_file_path, "wb") as f:
            pickle.dump(pretrain_embeddings, f)

    else:
        print("Validation-embeddings....")
        data_iter = tqdm.tqdm(enumerate(val_data_loader),
                              desc="validation",
                              total=len(val_data_loader),
                              bar_format="{l_bar}{r_bar}")
        val_embeddings = []
        for i, data in data_iter:
            data = {key: value.to(device) for key, value in data.items()}
            hrep = bert(data["bert_input"], data["segment_label"])

            # hrep is (batch, seq_len, hidden); keep position 0 of each sequence as its embedding.
            embeddings = [h for h in hrep[:, 0].cpu().detach().numpy()]
            val_embeddings.extend(embeddings)

        with open(embedding_file_path, "wb") as f:
            pickle.dump(val_embeddings, f)
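    # The pickled file can later be reloaded with, for example:
    #     with open(embedding_file_path, "rb") as f:
    #         embeddings = pickle.load(f)  # list of per-sequence embedding vectors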