import torch
from torch.utils.data import Dataset

from args import args


class items_dataset(Dataset):
    def __init__(self, tokenizer, data_set, label_dict, stride=0, max_length=args.max_length):
        self.data_set = data_set
        self.tokenizer = tokenizer
        self.label_dict = label_dict
        self.max_length = max_length
        self.encode_max_length = max_length - 2  # reserve room for [CLS] and [SEP]
        self.batch_max_lenght = max_length
        self.stride = stride

    def __getitem__(self, index):
        return self.data_set[index]

    def __len__(self):
        return len(self.data_set)

    def create_label_list(self, span_label, max_len):
        # Character-level BIO table: 0 = "O", 1 = "B", 2 = "I".
        table = torch.zeros(max_len)
        for start, end in span_label:
            table[start:end] = 2  # "I"
            table[start] = 1      # "B"
        return table

    def encode_lable(self, encoded, batch_table):
        # Project the character-level BIO tables onto the tokenized (sub-word) sequences.
        batch_encode_seq_lens = []
        sample_mapping = encoded["overflow_to_sample_mapping"]
        offset_mapping = encoded["offset_mapping"]
        encoded_label = torch.zeros(len(sample_mapping), self.encode_max_length, dtype=torch.long)
        for id_in_batch in range(len(sample_mapping)):
            encode_len = 0
            table = batch_table[sample_mapping[id_in_batch]]
            for i in range(self.max_length):
                char_start, char_end = offset_mapping[id_in_batch][i]
                # ignore [CLS], [SEP] and padding tokens, whose offsets are (0, 0)
                if char_start != 0 or char_end != 0:
                    encode_len += 1
                    # shift left by one so the label sequence starts after [CLS]
                    encoded_label[id_in_batch][i - 1] = table[char_start].long()
            batch_encode_seq_lens.append(encode_len)
        return encoded_label, batch_encode_seq_lens

    def create_crf_mask(self, batch_encode_seq_lens):
        # True on real token positions, False on padding; consumed by the CRF layer.
        mask = torch.zeros(len(batch_encode_seq_lens), self.encode_max_length, dtype=torch.bool)
        for i, batch_len in enumerate(batch_encode_seq_lens):
            mask[i][:batch_len] = True
        return mask

    def boundary_encoded(self, encodings, batch_boundary):
        # Convert character-based segment lengths into token-based segment lengths.
        batch_boundary_encoded = []
        for batch_id, span_labels in enumerate(batch_boundary):
            boundary_encoded = []
            end = 0
            for boundary in span_labels:
                end += boundary
                encoded_end = encodings[batch_id].char_to_token(end - 1)
                # back off until the character actually maps to a token
                tmp_end = end
                while encoded_end is None and tmp_end > 0:
                    tmp_end -= 1
                    encoded_end = encodings[batch_id].char_to_token(tmp_end - 1)
                if encoded_end is not None:
                    encoded_end += 1
                if encoded_end > self.encode_max_length:
                    boundary_encoded.append(self.encode_max_length)
                    break
                else:
                    boundary_encoded.append(encoded_end)
            # turn cumulative end positions back into per-segment lengths
            for i in range(len(boundary_encoded) - 1, 0, -1):
                boundary_encoded[i] = boundary_encoded[i] - boundary_encoded[i - 1]
            batch_boundary_encoded.append(boundary_encoded)
        return batch_boundary_encoded

    def cal_agreement_span(self, agreement_table, min_agree=2, max_agree=3):
        """
        Find the spans from the agreement table.
        """
        ans_span = []
        start, end = (0, 0)
        pre_p = agreement_table[0]
        for i, word_agreement in enumerate(agreement_table):
            curr_p = word_agreement
            if curr_p != pre_p:
                if start != end:
                    ans_span.append([start, end])
                start = i
                end = i
            pre_p = curr_p
            # assumed threshold check: extend the open span while agreement is in range
            if min_agree <= word_agreement <= max_agree:
                end = i + 1
        if start != end:
            ans_span.append([start, end])
        return ans_span

    def collate_fn(self, batch):
        # The sample fields "text" and "span_label" below are assumptions about the
        # schema of self.data_set; adjust them if the actual field names differ.
        batch_text = [sample["text"] for sample in batch]
        batch_span_label = [sample["span_label"] for sample in batch]
        batch_table = [self.create_label_list(span_label, len(text))
                       for text, span_label in zip(batch_text, batch_span_label)]
        # cap the recorded batch length at the longest encodable sequence
        self.batch_max_lenght = max(len(text) for text in batch_text)
        if self.batch_max_lenght > self.encode_max_length:
            self.batch_max_lenght = self.encode_max_length
        encoded = self.tokenizer(batch_text, truncation=True, max_length=512,
                                 padding='max_length', stride=self.stride,
                                 return_overflowing_tokens=True, return_tensors="pt",
                                 return_offsets_mapping=True)
        #encoded = self.tokenizer(batch_text, truncation=True, padding=True, return_tensors="pt", max_length=self.max_length)
        encoded['labels'], batch_encode_seq_lens = self.encode_lable(encoded, batch_table)
        encoded["crf_mask"] = self.create_crf_mask(batch_encode_seq_lens)
        #encoded["boundary"] = batch_boundary
        #encoded["boundary_encode"] = self.boundary_encoded(encoded, batch_boundary)
        encoded["span_labels"] = batch_span_label
        encoded["batch_text"] = batch_text
        return encoded
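

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions (not taken from the original file): a HuggingFace
    # fast tokenizer and samples shaped as {"text": str, "span_label": [[start, end], ...]},
    # mirroring the (assumed) field names read in collate_fn above.
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
    samples = [
        {"text": "台北市立動物園今天宣布新的保育計畫。", "span_label": [[0, 7]]},
        {"text": "研究團隊公布了最新的實驗結果。", "span_label": [[0, 4], [7, 13]]},
    ]
    dataset = items_dataset(tokenizer, samples, label_dict={"O": 0, "B": 1, "I": 2}, max_length=512)
    loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)

    batch = next(iter(loader))
    # labels: (num_windows, max_length-2) BIO ids; crf_mask: True on real token positions
    print(batch["labels"].shape, batch["crf_mask"].shape, batch["span_labels"])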