import torch
from torch.utils.data import Dataset

import copy
import random

import tqdm

from .vocab import Vocab


class PretrainerDataset(Dataset):
    """
    Dataset for masked-LM pretraining: reads one tab-separated token sequence
    per line (blank lines separate documents) and yields masked inputs with
    their corresponding labels.
    """
    def __init__(self, dataset_path, vocab, seq_len=30, max_mask=0.15):
        self.dataset_path = dataset_path
        self.vocab = vocab

        # Token sequences, plus a map from document id to line indices
        # (blank lines in the file delimit documents).
        self.lines = []
        self.index_documents = {}

        seq_len_list = []
        with open(self.dataset_path, "r") as reader:
            i = 0
            index = 0
            self.index_documents[i] = []
            for line in tqdm.tqdm(reader.readlines()):
                line = line.strip()
                if not line:
                    # Blank line: start a new document.
                    i += 1
                    self.index_documents[i] = []
                else:
                    self.index_documents[i].append(index)
                    tokens = line.split("\t")
                    self.lines.append(tokens)
                    seq_len_list.append(len(tokens))
                    index += 1

        print("Sequence Stats: len: %s, min: %s, max: %s, average: %s" % (
            len(seq_len_list), min(seq_len_list), max(seq_len_list),
            sum(seq_len_list) / len(seq_len_list)))
        print("Unique Sequences: ", len({tuple(ll) for ll in self.lines}))
        self.index_documents = {k: v for k, v in self.index_documents.items() if v}
        print("Number of documents: ", len(self.index_documents))

        self.seq_len = seq_len
        print("Sequence length set at: ", self.seq_len)
        self.max_mask = max_mask
        print("Fraction of input tokens selected for masking: ", self.max_mask)

    def __len__(self):
        return len(self.lines)
    def __getitem__(self, item):
        token_a = self.lines[item]

        # Reserve two positions for [CLS] and [SEP].
        token_a = token_a[:self.seq_len - 2]
        sa_masked, sa_masked_label, sa_masked_pos = self.random_mask_seq(token_a)

        s1 = [self.vocab.vocab['[CLS]']] + sa_masked + [self.vocab.vocab['[SEP]']]
        s1_label = [self.vocab.vocab['[PAD]']] + sa_masked_label + [self.vocab.vocab['[PAD]']]
        segment_label = [1 for _ in range(len(s1))]
        masked_pos = [0] + sa_masked_pos + [0]

        # Pad all sequences out to seq_len. masked_pos is a 0/1 indicator, so
        # it is padded with zeros rather than the [PAD] vocab id.
        padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
        s1.extend(padding)
        s1_label.extend(padding)
        segment_label.extend(padding)
        masked_pos.extend([0] * len(padding))

        output = {'bert_input': s1,
                  'bert_label': s1_label,
                  'segment_label': segment_label,
                  'masked_pos': masked_pos}
        return {key: torch.tensor(value) for key, value in output.items()}
    def random_mask_seq(self, tokens):
        """
        Apply BERT-style masking to a token sequence.

        Input: original token sequence
        Output: masked token-id sequence, masked-LM labels, 0/1 masked-position flags
        """
        masked_pos = []
        output_labels = []
        output_tokens = copy.deepcopy(tokens)
        opt_step = False
        for i, token in enumerate(tokens):
            if token in ['OptionalTask_1', 'EquationAnswer', 'NumeratorFactor',
                         'DenominatorFactor', 'OptionalTask_2', 'FirstRow1:1',
                         'FirstRow1:2', 'FirstRow2:1', 'FirstRow2:2',
                         'SecondRow', 'ThirdRow']:
                opt_step = True

            prob = random.random()
            if prob < self.max_mask:
                # Token selected for masking: 80% -> [MASK], 10% -> random
                # token, 10% -> kept as-is.
                prob = random.random()
                if prob < 0.8:
                    output_tokens[i] = self.vocab.vocab['[MASK]']
                    masked_pos.append(1)
                elif prob < 0.9:
                    if opt_step:
                        # Inside an optional-task step: replace with a random
                        # step token (hard-coded vocab indices).
                        output_tokens[i] = random.choice(
                            [7, 8, 9, 11, 12, 13, 14, 15, 16, 22, 23, 24,
                             25, 26, 27, 30, 31, 32])
                        opt_step = False
                    else:
                        output_tokens[i] = random.randint(1, len(self.vocab.vocab) - 1)
                    masked_pos.append(1)
                else:
                    # Keep the original token, mapped to its vocab id.
                    output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
                    masked_pos.append(0)
                # The label is always the original token's id.
                output_labels.append(self.vocab.vocab.get(token, self.vocab.vocab['[UNK]']))
            else:
                # Not selected: pass through; [PAD] labels are ignored by the loss.
                output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
                output_labels.append(self.vocab.vocab['[PAD]'])
                masked_pos.append(0)
        return output_tokens, output_labels, masked_pos
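
# Usage sketch for PretrainerDataset (illustrative only; the file paths are
# hypothetical placeholders, and Vocab is assumed to expose a `vocab` dict
# containing the special tokens used above):
#
#     from torch.utils.data import DataLoader
#     vocab = Vocab("data/vocab.txt")
#     dataset = PretrainerDataset("data/pretrain_sequences.txt", vocab, seq_len=30)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)
#     batch = next(iter(loader))
#     # batch['bert_input'], batch['bert_label'], batch['segment_label'] and
#     # batch['masked_pos'] are LongTensors of shape (batch_size, seq_len).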


class TokenizerDataset(Dataset):
    """
    Dataset that tokenizes tab-separated interaction sequences for
    fine-tuning, optionally with per-item labels and precomputed
    feature vectors.
    """
    def __init__(self, dataset_path, label_path, vocab, seq_len=30):
        self.dataset_path = dataset_path
        self.label_path = label_path
        self.vocab = vocab

        self.lines = []
        self.labels = []
        self.feats = []
        if self.label_path:
            # One integer label per non-blank line.
            with open(self.label_path, "r") as label_file:
                for line in label_file:
                    line = line.strip()
                    if not line:
                        continue
                    self.labels.append(int(line))

            # Optional feature file, located by replacing "label" with "info"
            # in the label path. Each line is comma-separated; the third- and
            # second-to-last fields hold tab-separated feature vectors.
            try:
                j = 0
                with open(self.label_path.replace("label", "info"), "r") as info_file:
                    for line in info_file:
                        line = line.strip()
                        if not line:
                            continue
                        feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
                        feat2 = [float(i) for i in line.split(",")[-2].split("\t")]
                        feat_vec.extend(feat2[1:])
                        if j == 0:
                            print("Feature vector length: ", len(feat_vec))
                        j += 1
                        self.feats.append(feat_vec)
            except Exception as e:
                print(e)

        with open(self.dataset_path, "r") as data_file:
            for line in data_file:
                line = line.strip()
                if line:
                    self.lines.append(line)

        self.len = len(self.lines)
        self.seq_len = seq_len
        print("Sequence length set at ", self.seq_len, len(self.lines),
              len(self.labels) if self.label_path else 0)

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        org_line = self.lines[item].split("\t")
        dup_line = []
        opt = False
        for l in org_line:
            if l in ["OptionalTask_1", "EquationAnswer", "NumeratorFactor",
                     "DenominatorFactor", "OptionalTask_2", "FirstRow1:1",
                     "FirstRow1:2", "FirstRow2:1", "FirstRow2:2",
                     "SecondRow", "ThirdRow"]:
                opt = True
            if opt and 'FinalAnswer-' in l:
                # Hide the final answer in optional-task sequences.
                dup_line.append('[UNK]')
            else:
                dup_line.append(l)
        dup_line = "\t".join(dup_line)

        s1 = self.vocab.to_seq(dup_line, self.seq_len)
        s1_label = self.labels[item] if self.label_path else 0
        segment_label = [1 for _ in range(len(s1))]
        s1_feat = self.feats[item] if len(self.feats) > 0 else 0
        padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
        s1.extend(padding)
        segment_label.extend(padding)

        output = {'input': s1,
                  'label': s1_label,
                  'feat': s1_feat,
                  'segment_label': segment_label}
        return {key: torch.tensor(value) for key, value in output.items()}
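
# Usage sketch for TokenizerDataset (illustrative only; paths are hypothetical
# placeholders). The feature file is discovered by replacing "label" with
# "info" in the label path, as implemented above:
#
#     vocab = Vocab("data/vocab.txt")
#     dataset = TokenizerDataset("data/finetune.txt", "data/finetune_label.txt",
#                                vocab, seq_len=30)
#     item = dataset[0]
#     # item['input'] and item['segment_label'] have length seq_len;
#     # item['label'] is the class label and item['feat'] the feature vector.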


class TokenizerDatasetForCalibration(TokenizerDataset):
    """
    Identical to TokenizerDataset, except that __getitem__ returns an
    (inputs, label) pair so calibration code can consume it directly.
    """

    def __getitem__(self, item):
        s1_label = self.labels[item] if self.label_path else 0
        return super().__getitem__(item), s1_label
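
# Calibration usage sketch (illustrative only; paths are hypothetical
# placeholders):
#
#     calib_set = TokenizerDatasetForCalibration("data/finetune.txt",
#                                                "data/finetune_label.txt",
#                                                vocab, seq_len=30)
#     inputs, label = calib_set[0]   # dict of tensors, plus the raw int label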