import copy
import pickle
import random

import numpy as np
import pandas as pd
import torch
import tqdm
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import Dataset

from vocab import Vocab


class PretrainerDataset(Dataset):
    """
    PretrainerDataset

    Reads whitespace-separated token sequences (one sequence per line,
    documents separated by blank lines) and yields BERT-style masked inputs,
    labels, and segment ids for pretraining.
    """
    def __init__(self, dataset_path, vocab, seq_len=30, select_next_seq=False):
        self.dataset_path = dataset_path
        self.vocab = vocab

        # One entry per non-empty line; blank lines mark document boundaries.
        self.lines = []
        self.index_documents = {}

        seq_len_list = []
        with open(self.dataset_path, "r") as reader:
            i = 0       # document id
            index = 0   # global sequence index
            self.index_documents[i] = []
            for line in tqdm.tqdm(reader.readlines()):
                line = line.strip()
                if not line:
                    # Blank line: start a new document.
                    i += 1
                    self.index_documents[i] = []
                else:
                    self.index_documents[i].append(index)
                    tokens = line.split()
                    self.lines.append(tokens)
                    seq_len_list.append(len(tokens))
                    index += 1

        print("Sequence Stats: ", len(seq_len_list), min(seq_len_list),
              max(seq_len_list), sum(seq_len_list) / len(seq_len_list))
        print("Unique Sequences: ", len({tuple(ll) for ll in self.lines}))
        # Drop empty documents (e.g. produced by consecutive blank lines).
        self.index_documents = {k: v for k, v in self.index_documents.items() if v}
        self.seq_len = seq_len
        self.max_mask_per_seq = 0.15
        self.select_next_seq = select_next_seq
        print("Sequence length set at ", self.seq_len)
        print("select_next_seq: ", self.select_next_seq)
        print("Number of documents: ", len(self.index_documents))

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, item):
        token_a = self.lines[item]
        token_b = None
        is_same_student = None
        sb_masked = None
        sb_masked_label = None

        if self.select_next_seq:
            # Next-sequence task: pair token_a with a sequence from the same
            # or a different document, then mask both segments.
            is_same_student, token_b = self.get_token_b(item)
            is_same_student = 1 if is_same_student else 0
            token_a1, token_b1 = self.truncate_to_max_seq(token_a, token_b)
            sa_masked, sa_masked_label = self.random_mask_seq(token_a1)
            sb_masked, sb_masked_label = self.random_mask_seq(token_b1)
        else:
            # Reserve two positions for [CLS] and [SEP].
            token_a = token_a[:self.seq_len - 2]
            sa_masked, sa_masked_label = self.random_mask_seq(token_a)

        s1 = [self.vocab.vocab['[CLS]']] + sa_masked + [self.vocab.vocab['[SEP]']]
        s1_label = [self.vocab.vocab['[PAD]']] + sa_masked_label + [self.vocab.vocab['[PAD]']]
        segment_label = [1 for _ in range(len(s1))]

        if self.select_next_seq:
            s1 = s1 + sb_masked + [self.vocab.vocab['[SEP]']]
            s1_label = s1_label + sb_masked_label + [self.vocab.vocab['[PAD]']]
            segment_label = segment_label + [2 for _ in range(len(sb_masked) + 1)]

        # Pad all outputs to the fixed sequence length.
        padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
        s1.extend(padding)
        s1_label.extend(padding)
        segment_label.extend(padding)

        output = {'bert_input': s1,
                  'bert_label': s1_label,
                  'segment_label': segment_label}

        if self.select_next_seq:
            output['is_same_student'] = is_same_student

        return {key: torch.tensor(value) for key, value in output.items()}

    def random_mask_seq(self, tokens):
        """
        Input: original token sequence.
        Output: masked token-id sequence and the corresponding label sequence
        ([PAD] at unmasked positions, the original token id at masked ones).
        """
        output_labels = []
        output_tokens = copy.deepcopy(tokens)

        for i, token in enumerate(tokens):
            prob = random.random()
            if prob < 0.15:
                # Mask 15% of tokens, following the BERT 80/10/10 scheme.
                prob = random.random()
                if prob < 0.8:
                    # 80%: replace with the [MASK] token.
                    output_tokens[i] = self.vocab.vocab['[MASK]']
                elif prob < 0.9:
                    # 10%: replace with a random token id.
                    output_tokens[i] = random.randint(1, len(self.vocab.vocab) - 1)
                else:
                    # 10%: keep the original token.
                    output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
                # The label at a masked position is the original token id.
                output_labels.append(self.vocab.vocab.get(token, self.vocab.vocab['[UNK]']))
            else:
                # Unmasked position: convert to its token id, label it [PAD].
                output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
                output_labels.append(self.vocab.vocab['[PAD]'])

        return output_tokens, output_labels

    def get_token_b(self, item):
        """
        Pick a second sequence for the next-sequence task: with probability 0.5
        it comes from the same document as `item`, otherwise from a randomly
        chosen other document. Returns (same_student, tokens).
        """
        document_id = [k for k, v in self.index_documents.items() if item in v][0]
        random_document_id = document_id

        if random.random() < 0.5:
            document_ids = [k for k in self.index_documents.keys() if k != document_id]
            random_document_id = random.choice(document_ids)

        same_student = (random_document_id == document_id)

        next_seq_list = self.index_documents.get(random_document_id)

        if same_student and len(next_seq_list) != 1:
            # Avoid pairing a sequence with itself when alternatives exist.
            next_seq_list = [v for v in next_seq_list if v != item]

        next_seq = random.choice(next_seq_list)
        tokens = self.lines[next_seq]

        return same_student, tokens

    def truncate_to_max_seq(self, s1, s2):
        """
        Randomly pop tokens from the ends of s1/s2 until their combined length
        fits in seq_len - 3 (room for [CLS] and two [SEP] tokens).
        """
        sa = copy.deepcopy(s1)
        sb = copy.deepcopy(s2)
        total_allowed_seq = self.seq_len - 3

        while (len(sa) + len(sb)) > total_allowed_seq:
            if random.random() < 0.5:
                sa.pop()
            else:
                sb.pop()
        return sa, sb

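# A minimal usage sketch for PretrainerDataset (an illustration, not part of the
# original pipeline): the file path below is hypothetical, and the Vocab
# construction is an assumption -- the dataset only relies on a `.vocab` dict
# mapping '[CLS]', '[SEP]', '[PAD]', '[MASK]', '[UNK]' and the data tokens to ids.
#
#     vocab = Vocab(...)  # built however vocab.py expects
#     dataset = PretrainerDataset("data/pretrain_sequences.txt", vocab,
#                                 seq_len=30, select_next_seq=True)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
#     batch = next(iter(loader))
#     # batch['bert_input'], batch['bert_label'], batch['segment_label'] are
#     # (32, 30) LongTensors; batch['is_same_student'] has shape (32,).
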
class TokenizerDataset(Dataset):
    """
    TokenizerDataset

    Converts each line of `dataset_path` into token ids with the supplied
    vocab and pairs it with a one-hot encoded label read from `label_path`.
    """
    def __init__(self, dataset_path, label_path, vocab, seq_len=30, train=True):
        self.dataset_path = dataset_path
        self.label_path = label_path
        self.vocab = vocab
        self.encoder = OneHotEncoder(sparse_output=False)

        self.lines = []
        self.labels = []

        # Read one numeric label per non-empty line and one-hot encode them.
        with open(self.label_path, "r") as label_file:
            for line in label_file:
                line = line.strip()
                if not line:
                    continue
                self.labels.append(float(line))
        labeler = np.unique(self.labels)
        self.encoder.fit(labeler.reshape(-1, 1))
        self.labels = self.encoder.transform(np.array(self.labels).reshape(-1, 1))

        # Read the raw sequences; keep each non-empty line as a single string.
        with open(self.dataset_path, "r") as data_file:
            for line in data_file:
                line = line.strip()
                if line:
                    self.lines.append(line)

        self.len = len(self.lines)
        self.seq_len = seq_len

        print("Sequence length set at ", self.seq_len,
              "num sequences:", len(self.lines),
              "num labels:", len(self.labels))

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        # Convert the raw line to token ids, then pad input and segment ids.
        s1 = self.vocab.to_seq(self.lines[item], self.seq_len)
        s1_label = self.labels[item]
        segment_label = [1 for _ in range(len(s1))]

        padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
        s1.extend(padding)
        segment_label.extend(padding)

        output = {'bert_input': s1,
                  'progress_status': s1_label,
                  'segment_label': segment_label}
        return {key: torch.tensor(value) for key, value in output.items()}

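
if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustration, not part of the original
    # training pipeline): it writes a tiny sequence file and label file to
    # temporary paths and fakes the two pieces of the Vocab interface that
    # TokenizerDataset relies on (a `.vocab` dict containing '[PAD]' and a
    # `to_seq(line, seq_len)` method). All names and data below are assumptions.
    import os
    import tempfile

    class _FakeVocab:
        def __init__(self):
            self.vocab = {'[PAD]': 0, '[UNK]': 1, 'a': 2, 'b': 3, 'c': 4}

        def to_seq(self, sentence, seq_len):
            ids = [self.vocab.get(tok, self.vocab['[UNK]']) for tok in sentence.split()]
            return ids[:seq_len]

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("a b c\nb a\nc c a b\n")
        seq_path = f.name
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("0\n1\n0\n")
        label_path = f.name

    ds = TokenizerDataset(seq_path, label_path, _FakeVocab(), seq_len=6)
    sample = ds[0]
    print(sample['bert_input'], sample['progress_status'], sample['segment_label'])

    os.remove(seq_path)
    os.remove(label_path)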