File size: 6,132 Bytes
6a34fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import numpy

import pickle
import tqdm

from bert import BERT
from vocab import Vocab
from dataset import TokenizerDataset
import argparse
from itertools import combinations

def generate_subset(s):
    """Enumerate every subset of ``s`` and index them in a dict.

    Subsets are produced in order of increasing size: the empty subset comes
    first and ``s`` itself comes last. Within one size they follow
    ``itertools.combinations`` order, i.e. the iteration order of ``s``.

    Args:
        s: A sized iterable of elements (e.g. a list or a set). If ``s`` is a
            ``set`` of strings, its iteration order can change between
            interpreter runs because of string hash randomisation. The index
            assigned to each subset then changes too. Pass a sorted list if
            the mapping must be reproducible.

    Returns:
        dict[int, list]: ``{index: subset_as_list}`` with ``2 ** len(s)``
        entries.
    """
    subsets = [
        list(combo)
        for size in range(len(s) + 1)
        for combo in combinations(s, size)
    ]
    return dict(enumerate(subsets))

def str2bool(value):
    """Parse a boolean command-line option value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``. That is ``True`` for
    every non-empty string, including ``"False"``, so ``-pretrain False``
    would still enable the pretrain branch. This converter accepts the usual
    spellings instead.

    Args:
        value: The raw string from the command line. A ``bool`` is passed
            through unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling. argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def extract_cls_embeddings(model, data_loader, device, desc):
    """Run ``model`` over ``data_loader`` and collect position 0 of each output sequence.

    Args:
        model: The loaded BERT model. It is called as
            ``model(bert_input, segment_label)``.
        data_loader: Yields dicts of tensors containing at least
            ``"bert_input"`` and ``"segment_label"``. Every tensor in the dict
            is moved to ``device``.
        device: The ``torch.device`` the batches are moved to before the
            forward pass.
        desc: Label shown on the tqdm progress bar.

    Returns:
        A list of NumPy arrays, one per example, in data-loader order.
    """
    embeddings = []
    data_iter = tqdm.tqdm(data_loader,
                          desc=desc,
                          total=len(data_loader),
                          bar_format="{l_bar}{r_bar}")
    # No gradients are needed for feature extraction, so skip building the graph.
    with torch.no_grad():
        for batch in data_iter:
            batch = {key: value.to(device) for key, value in batch.items()}
            hrep = model(batch["bert_input"], batch["segment_label"])
            # hrep[:, 0] takes sequence position 0 for every example.
            # Presumably hrep is (batch, seq_len, hidden) and position 0 is
            # the [CLS]-style token — TODO confirm against the BERT/Vocab
            # definitions. Iterating the resulting 2-D array yields one row
            # per example.
            embeddings.extend(hrep[:, 0].cpu().detach().numpy())
    return embeddings


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('-workspace_name', type=str, default=None)
    parser.add_argument("-seq_len", type=int, default=100, help="maximum sequence length")
    # type=bool would map any non-empty string (even "False") to True,
    # so boolean options are parsed with str2bool.
    parser.add_argument('-pretrain', type=str2bool, default=False)
    parser.add_argument('-masked_pred', type=str2bool, default=False)
    parser.add_argument('-epoch', type=str, default=None)

    options = parser.parse_args()

    folder_path = options.workspace_name + "/" if options.workspace_name else ""

    vocab_path = f"{folder_path}check/pretraining/vocab.txt"
    print("Loading Vocab", vocab_path)
    vocab_obj = Vocab(vocab_path)
    vocab_obj.load_vocab()
    print("Vocab Size: ", len(vocab_obj.vocab))

    if options.masked_pred:
        str_code = "masked_prediction"
        output_name = f"{folder_path}output/bert_trained.seq_model.ep{options.epoch}"
    else:
        str_code = "masked"
        output_name = f"{folder_path}output/bert_trained.seq_encoder.model.ep{options.epoch}"

    # The model checkpoint above is read from <workspace>/output/. The
    # datasets and the embedding outputs below live under <workspace>/check/.
    folder_path = folder_path + "check/"
    if options.pretrain:
        data_file = f"{folder_path}pretraining/pretrain.txt"
        # NOTE(review): pretrain labels are a .pkl while test labels are a .txt —
        # confirm TokenizerDataset accepts both formats.
        label_file = f"{folder_path}pretraining/pretrain_opt.pkl"
        embedding_file_path = f"{folder_path}embeddings/pretrain_embeddings_{str_code}_{options.epoch}.pkl"
        print("Loading Pretrain Dataset ", data_file)
        progress_desc, start_message = "pre-train", "Pretrain-embeddings...."
    else:
        data_file = f"{folder_path}pretraining/test.txt"
        label_file = f"{folder_path}pretraining/test_opt.txt"
        embedding_file_path = f"{folder_path}embeddings/test_embeddings_{str_code}_{options.epoch}.pkl"
        print("Loading Validation Dataset ", data_file)
        progress_desc, start_message = "validation", "Validation-embeddings...."

    dataset = TokenizerDataset(data_file, label_file, vocab_obj, seq_len=options.seq_len)

    print("Creating Dataloader")
    # shuffle stays at its default (False) so embeddings keep the dataset's order.
    data_loader = DataLoader(dataset, batch_size=32, num_workers=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    print("Load Pre-trained BERT model...")
    print(output_name)
    # NOTE: torch.load unpickles a whole model object, so only load trusted
    # checkpoints. On PyTorch >= 2.6 the default weights_only=True refuses full
    # model objects; pass weights_only=False there if loading fails.
    bert = torch.load(output_name, map_location=device)
    for param in bert.parameters():
        param.requires_grad = False
    # The checkpoint may have been saved in train mode. eval() disables
    # dropout (and other train-only behaviour) so the embeddings are
    # deterministic.
    bert.eval()

    print(start_message)
    embeddings = extract_cls_embeddings(bert, data_loader, device, progress_desc)
    with open(embedding_file_path, "wb") as out_file:
        pickle.dump(embeddings, out_file)