# claim_detection/code/items_dataset.py
import torch
from torch.utils.data import Dataset
from args import args


class items_dataset(Dataset):
    def __init__(self, tokenizer, data_set, label_dict, stride=0, max_length=args.max_length):
        self.data_set = data_set
        self.tokenizer = tokenizer
        self.label_dict = label_dict
        self.max_length = max_length
        self.encode_max_length = max_length - 2  # reserve two positions for [CLS] and [SEP]
        self.batch_max_length = max_length
        self.stride = stride

    def __getitem__(self, index):
        return self.data_set[index]

    def __len__(self):
        return len(self.data_set)

    def create_label_list(self, span_label, max_len):
        """Turn character-level [start, end) spans into a BIO tag table (0 = "O", 1 = "B", 2 = "I")."""
        table = torch.zeros(max_len)
        for start, end in span_label:
            table[start:end] = 2  # "I" for every character inside the span
            table[start] = 1      # "B" for the first character of the span
        return table

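    # Illustrative example (values assumed for this sketch): for a 6-character text with a single
    # labeled span over characters 1..3, create_label_list([[1, 4]], 6) returns the tensor
    # [0., 1., 2., 2., 0., 0.], i.e. O B I I O O under the scheme documented above.
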
    def encode_label(self, encoded, batch_table):
        """Project the character-level label tables onto the tokenized (possibly overflowed) sequences."""
        batch_encode_seq_lens = []
        sample_mapping = encoded["overflow_to_sample_mapping"]
        offset_mapping = encoded["offset_mapping"]
        encoded_label = torch.zeros(len(sample_mapping), self.encode_max_length, dtype=torch.long)
        for id_in_batch in range(len(sample_mapping)):
            encode_len = 0
            table = batch_table[sample_mapping[id_in_batch]]
            for i in range(self.max_length):
                char_start, char_end = offset_mapping[id_in_batch][i]
                # special tokens ([CLS], [SEP]) and padding all carry the offset (0, 0) and are skipped
                if char_start != 0 or char_end != 0:
                    encode_len += 1
                    # shift left by one so the label tensor has no slot for [CLS]
                    encoded_label[id_in_batch][i - 1] = table[char_start].long()
            batch_encode_seq_lens.append(encode_len)
        return encoded_label, batch_encode_seq_lens

    def create_crf_mask(self, batch_encode_seq_lens):
        """Build a boolean mask marking the positions that hold real tokens, for use by the CRF layer."""
        mask = torch.zeros(len(batch_encode_seq_lens), self.encode_max_length, dtype=torch.bool)
        for i, batch_len in enumerate(batch_encode_seq_lens):
            mask[i][:batch_len] = True
        return mask

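    # Illustrative example (values assumed for this sketch): batch_encode_seq_lens = [3, 5] with
    # encode_max_length = 8 gives mask rows [T, T, T, F, F, F, F, F] and [T, T, T, T, T, F, F, F],
    # so the CRF only scores positions backed by real tokens of each sequence.
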
    def boundary_encoded(self, encodings, batch_boundary):
        """Map character-level segment lengths onto token-level segment lengths."""
        batch_boundary_encoded = []
        for batch_id, span_labels in enumerate(batch_boundary):
            boundary_encoded = []
            end = 0
            for boundary in span_labels:
                end += boundary
                encoded_end = encodings[batch_id].char_to_token(end - 1)
                # if the boundary character maps to no token (e.g. whitespace), step back until one does
                tmp_end = end
                while encoded_end is None and tmp_end > 0:
                    tmp_end -= 1
                    encoded_end = encodings[batch_id].char_to_token(tmp_end - 1)
                if encoded_end is not None:
                    encoded_end += 1
                if encoded_end > self.encode_max_length:
                    boundary_encoded.append(self.encode_max_length)
                    break
                else:
                    boundary_encoded.append(encoded_end)
            # convert cumulative token positions back into per-segment token lengths
            for i in range(len(boundary_encoded) - 1, 0, -1):
                boundary_encoded[i] = boundary_encoded[i] - boundary_encoded[i - 1]
            batch_boundary_encoded.append(boundary_encoded)
        return batch_boundary_encoded

    def cal_agreement_span(self, agreement_table, min_agree=2, max_agree=3):
        """Find the spans of positions whose annotator agreement count reaches min_agree."""
        ans_span = []
        start, end = (0, 0)
        pre_p = agreement_table[0]
        for i, word_agreement in enumerate(agreement_table):
            curr_p = word_agreement
            if curr_p != pre_p:
                # agreement level changed: close the current span (if non-empty) and start a new one
                if start != end: ans_span.append([start, end])
                start = i
                end = i
                pre_p = curr_p
            if word_agreement < min_agree:
                start += 1
            if word_agreement <= max_agree:
                end += 1
            pre_p = curr_p
        if start != end: ans_span.append([start, end])
        if len(ans_span) <= 1 or min_agree == max_agree:
            return ans_span
        # merge adjacent spans that touch each other
        span_concate = []
        for span_id in range(1, len(ans_span)):
            if ans_span[span_id - 1][1] == ans_span[span_id][0]:
                ans_span[span_id] = [ans_span[span_id - 1][0], ans_span[span_id][1]]
                if span_id == len(ans_span) - 1: span_concate.append(ans_span[span_id])
            elif span_id == len(ans_span) - 1:
                span_concate.extend([ans_span[span_id - 1], ans_span[span_id]])
            else:
                span_concate.append(ans_span[span_id - 1])
        return span_concate

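    # Illustrative example (values assumed for this sketch): with agreement_table = [0, 2, 3, 3, 1, 2, 2, 0]
    # and the defaults min_agree=2, max_agree=3, the first pass collects [[1, 2], [2, 4], [5, 7]] and the
    # merge step joins the touching spans into [[1, 4], [5, 7]], i.e. the positions where at least two
    # annotators agreed.
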
    def collate_fn(self, batch_sample):
        """Tokenize a batch of samples and attach token-level labels, the CRF mask, and the raw spans/text."""
        batch_text = []
        batch_table = []
        batch_span_label = []
        seq_lens = []
        for sample in batch_sample:
            batch_text.append(sample['original_text'])
            batch_table.append(self.create_label_list(sample["span_labels"], len(sample['original_text'])))
            #batch_boundary = [sample['data_len_c'] for sample in batch_sample]
            batch_span_label.append(sample["span_labels"])
            seq_lens.append(len(sample['original_text']))
        self.batch_max_length = min(max(seq_lens), self.encode_max_length)
        # use self.max_length here so the offsets stay aligned with encode_label and create_crf_mask
        encoded = self.tokenizer(batch_text, truncation=True, max_length=self.max_length, padding='max_length',
                                 stride=self.stride, return_overflowing_tokens=True,
                                 return_tensors="pt", return_offsets_mapping=True)
        encoded['labels'], batch_encode_seq_lens = self.encode_label(encoded, batch_table)
        encoded["crf_mask"] = self.create_crf_mask(batch_encode_seq_lens)
        #encoded["boundary"] = batch_boundary
        #encoded["boundary_encode"] = self.boundary_encoded(encoded, batch_boundary)
        encoded["span_labels"] = batch_span_label
        encoded["batch_text"] = batch_text
        return encoded
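

# Minimal usage sketch (not part of the original pipeline). It assumes a fast tokenizer (offset
# mappings require one) and samples shaped the way collate_fn expects:
# {'original_text': str, 'span_labels': [[start, end], ...]}. The model name, the samples, and
# the label_dict below are placeholders.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)
    samples = [
        {"original_text": "The earth is flat according to this post.", "span_labels": [[0, 17]]},
        {"original_text": "No claim appears in this sentence.", "span_labels": []},
    ]
    label_dict = {"O": 0, "B": 1, "I": 2}  # assumed label scheme, mirroring create_label_list

    dataset = items_dataset(tokenizer, samples, label_dict, max_length=512)
    loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=dataset.collate_fn)
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["labels"].shape, batch["crf_mask"].shape)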