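"""Dataset classes for pretraining and fine-tuning.

PretrainerDataset builds BERT-style masked-sequence examples (with optional
next-sequence pairs), and TokenizerDataset tokenizes sequences and one-hot
encodes their labels for fine-tuning.
"""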
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import tqdm
import random
from vocab import Vocab
import pickle
import copy
from sklearn.preprocessing import OneHotEncoder
class PretrainerDataset(Dataset):
"""
Dataset for BERT-style masked-sequence pretraining.
Wraps each sequence with [CLS]/[SEP], randomly masks tokens, pads to seq_len,
and (when select_next_seq is True) pairs it with a second sequence for a
same-student prediction task.
"""
def __init__(self, dataset_path, vocab, seq_len=30, select_next_seq=False):
self.dataset_path = dataset_path
self.vocab = vocab # Vocab object
# Related to input dataset file
self.lines = []
self.index_documents = {}
seq_len_list = []
with open(self.dataset_path, "r") as reader:
i = 0
index = 0
self.index_documents[i] = []
for line in tqdm.tqdm(reader.readlines()):
if line:
line = line.strip()
if not line:
i+=1
self.index_documents[i] = []
else:
self.index_documents[i].append(index)
self.lines.append(line.split())
len_line = len(line.split())
seq_len_list.append(len_line)
index+=1
print("Sequence Stats: ", len(seq_len_list), min(seq_len_list), max(seq_len_list), sum(seq_len_list)/len(seq_len_list))
print("Unique Sequences: ", len({tuple(ll) for ll in self.lines}))
self.index_documents = {k:v for k,v in self.index_documents.items() if v}
self.seq_len = seq_len
self.max_mask_per_seq = 0.15
self.select_next_seq = select_next_seq
print("Sequence length set at ", self.seq_len)
print("select_next_seq: ", self.select_next_seq)
print(len(self.index_documents))
def __len__(self):
return len(self.lines)
def __getitem__(self, item):
token_a = self.lines[item]
token_b = None
is_same_student = None
sa_masked = None
sa_masked_label = None
sb_masked = None
sb_masked_label = None
if self.select_next_seq:
is_same_student, token_b = self.get_token_b(item)
is_same_student = 1 if is_same_student else 0
token_a1, token_b1 = self.truncate_to_max_seq(token_a, token_b)
sa_masked, sa_masked_label = self.random_mask_seq(token_a1)
sb_masked, sb_masked_label = self.random_mask_seq(token_b1)
else:
token_a = token_a[:self.seq_len-2]
sa_masked, sa_masked_label = self.random_mask_seq(token_a)
s1 = ([self.vocab.vocab['[CLS]']] + sa_masked + [self.vocab.vocab['[SEP]']])
s1_label = ([self.vocab.vocab['[PAD]']] + sa_masked_label + [self.vocab.vocab['[PAD]']])
segment_label = [1 for _ in range(len(s1))]
if self.select_next_seq:
s1 = s1 + sb_masked + [self.vocab.vocab['[SEP]']]
s1_label = s1_label + sb_masked_label + [self.vocab.vocab['[PAD]']]
segment_label = segment_label + [2 for _ in range(len(sb_masked)+1)]
padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
s1.extend(padding), s1_label.extend(padding), segment_label.extend(padding)
output = {'bert_input': s1,
'bert_label': s1_label,
'segment_label': segment_label}
if self.select_next_seq:
output['is_same_student'] = is_same_student
# print(item, len(s1), len(s1_label), len(segment_label))
return {key: torch.tensor(value) for key, value in output.items()}
def random_mask_seq(self, tokens):
"""
Input: original token seq
Output: masked token seq, output label
"""
# masked_pos_label = {}
output_labels = []
output_tokens = copy.deepcopy(tokens)
# while(len(label_tokens) < self.max_mask_per_seq*len(tokens)):
for i, token in enumerate(tokens):
prob = random.random()
if prob < 0.15:
# chooses 15% of token positions at random
# prob /= 0.15
prob = random.random()
if prob < 0.8: #[MASK] token 80% of the time
output_tokens[i] = self.vocab.vocab['[MASK]']
elif prob < 0.9: # a random token 10% of the time
# print(".......0.8-0.9......")
output_tokens[i] = random.randint(1, len(self.vocab.vocab)-1)
else: # the unchanged i-th token 10% of the time
# print(".......unchanged......")
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
# True Label
output_labels.append(self.vocab.vocab.get(token, self.vocab.vocab['[UNK]']))
# masked_pos_label[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
else:
# i-th token with original value
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
# Padded label
output_labels.append(self.vocab.vocab['[PAD]'])
# label_position = []
# label_tokens = []
# for k, v in masked_pos_label.items():
# label_position.append(k)
# label_tokens.append(v)
return output_tokens, output_labels
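# Illustrative example (made-up token names/ids): for tokens ["s1", "s2", "s3"],
# a run where only position 0 is selected and masked returns
# output_tokens = [vocab['[MASK]'], vocab['s2'], vocab['s3']] and
# output_labels = [vocab['s1'], vocab['[PAD]'], vocab['[PAD]']], i.e. only the
# masked position carries its true token id as a prediction target.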
def get_token_b(self, item):
document_id = [k for k,v in self.index_documents.items() if item in v][0]
random_document_id = document_id
if random.random() < 0.5:
document_ids = [k for k in self.index_documents.keys() if k != document_id]
random_document_id = random.choice(document_ids)
same_student = (random_document_id == document_id)
nex_seq_list = self.index_documents.get(random_document_id)
if same_student:
if len(nex_seq_list) != 1:
nex_seq_list = [v for v in nex_seq_list if v !=item]
next_seq = random.choice(nex_seq_list)
tokens = self.lines[next_seq]
# print(f"item = {item}, tokens: {tokens}")
# print(f"item={item}, next={next_seq}, same_student = {same_student}, {document_id} == {random_document_id}, b. {tokens}")
return same_student, tokens
def truncate_to_max_seq(self, s1, s2):
sa = copy.deepcopy(s1)
sb = copy.deepcopy(s2)
total_allowed_seq = self.seq_len - 3
while((len(sa)+len(sb)) > total_allowed_seq):
if random.random() < 0.5:
sa.pop()
else:
sb.pop()
return sa, sb
class TokenizerDataset(Dataset):
"""
Dataset for fine-tuning: tokenizes each input line with the Vocab object,
pads it to seq_len, and pairs it with a one-hot encoded progress label.
"""
def __init__(self, dataset_path, label_path, vocab, seq_len=30, train=True):
self.dataset_path = dataset_path
self.label_path = label_path
self.vocab = vocab # Vocab object
self.encoder = OneHotEncoder(sparse_output=False)
# Related to input dataset file
self.lines = []
self.labels = []
self.label_file = open(self.label_path, "r")
for line in self.label_file:
if line:
line = line.strip()
if not line:
continue
self.labels.append(float(line))
self.label_file.close()
labeler = np.unique(self.labels)
self.encoder.fit(labeler.reshape(-1,1))
self.labels = self.encoder.transform(np.array(self.labels).reshape(-1,1))
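# e.g. raw labels [0.0, 1.0, 0.0] are fit on the unique values [0., 1.] and
# become dense one-hot rows [[1., 0.], [0., 1.], [1., 0.]] (sparse_output=False).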
# print(f"labels: {self.labels}")
# info_file_name = self.dataset_path.split('.')
# info_file_name = info_file_name[0]+"_info."+info_file_name[1]
# progress = []
# with open(info_file_name, "r") as f:
# for line in f:
# if line:
# line = line.strip()
# if not line:
# continue
# line = line.split(",")[0]
# pstat = 1 if line == "GRADUATED" else 0
# progress.append(pstat)
# f.close()
# indices_of_grad = np.where(np.array(progress) == 1)[0]
# indices_of_prom = np.where(np.array(progress) == 0)[0]
# indices_of_zeros = np.where(np.array(labels) == 0)[0]
# indices_of_ones = np.where(np.array(labels) == 1)[0]
# number_of_items = min(len(indices_of_zeros), len(indices_of_ones))
# # number_of_items = min(len(indices_of_grad), len(indices_of_prom))
# print(number_of_items)
# indices_of_zeros = indices_of_zeros[:number_of_items]
# indices_of_ones = indices_of_ones[:number_of_items]
# print(indices_of_zeros)
# print(indices_of_ones)
# indices_of_grad = indices_of_grad[:number_of_items]
# indices_of_prom = indices_of_prom[:number_of_items]
# print(indices_of_grad)
# print(indices_of_prom)
self.file = open(self.dataset_path, "r")
# index = 0
for line in self.file:
if line:
line = line.strip()
if line:
self.lines.append(line)
# if train:
# if index in indices_of_zeros:
# # if index in indices_of_prom:
# self.lines.append(line)
# self.labels.append(0)
# if index in indices_of_ones:
# # if index in indices_of_grad:
# self.lines.append(line)
# self.labels.append(1)
# else:
# self.lines.append(line)
# self.labels.append(labels[index])
# self.labels.append(progress[index])
# index += 1
self.file.close()
self.len = len(self.lines)
self.seq_len = seq_len
print("Sequence length set at ", self.seq_len, len(self.lines), len(self.labels))
def __len__(self):
return self.len
def __getitem__(self, item):
s1 = self.vocab.to_seq(self.lines[item], self.seq_len) # tokenizes the line and adds [CLS] and [SEP]
s1_label = self.labels[item]
segment_label = [1 for _ in range(len(s1))]
padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
s1.extend(padding), segment_label.extend(padding)
output = {'bert_input': s1,
'progress_status': s1_label,
'segment_label': segment_label}
return {key: torch.tensor(value) for key, value in output.items()}
# if __name__ == "__main__":
# # import pickle
# # k = pickle.load(open("dataset/CL4999_1920/unique_steps_list.pkl","rb"))
# # print(k)
# vocab_obj = Vocab("pretraining/vocab.txt")
# vocab_obj.load_vocab()
# datasetTrain = PretrainerDataset("pretraining/pretrain.txt", vocab_obj)
# print(datasetTrain, len(datasetTrain))#, datasetTrain.documents_index)
# print(datasetTrain[len(datasetTrain)-1])
# for i, d in enumerate(datasetTrain):
# print(d.items())
# break
# fine_tune = TokenizerDataset("finetuning/finetune.txt", "finetuning/finetune_label.txt", vocab_obj)
# print(fine_tune)
# print(fine_tune[len(fine_tune)-1])
# print(fine_tune[random.randint(0, len(fine_tune))])
# for i, d in enumerate(fine_tune):
# print(d.items())
# break
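# Minimal usage sketch (assumes the example file layout from the commented-out
# block above: pretraining/vocab.txt and pretraining/pretrain.txt; the batch
# size and shuffle settings are illustrative, not prescribed by this module):
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    vocab_obj = Vocab("pretraining/vocab.txt")
    vocab_obj.load_vocab()
    pretrain_data = PretrainerDataset("pretraining/pretrain.txt", vocab_obj, seq_len=30)
    loader = DataLoader(pretrain_data, batch_size=32, shuffle=True)
    batch = next(iter(loader))
    # Every item is padded to seq_len, so each field stacks to a (batch_size, seq_len) tensor.
    print(batch['bert_input'].shape, batch['bert_label'].shape, batch['segment_label'].shape)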