import torch
from args import args, config
from tqdm import tqdm
from items_dataset import items_dataset


def _mark_spans(test_batch, batch_id, spans, span_table):
    """Write token-level ``spans`` of one chunk into a character-level table.

    Args:
        test_batch: Batch object; ``test_batch[batch_id]`` must expose
            ``token_to_chars(token_index)`` returning an indexable character
            span, or ``None`` for tokens with no source characters.
        batch_id: Index of the chunk inside ``test_batch``.
        spans: Iterable of mutable ``[start_token, end_token]`` pairs. The
            pairs are adjusted in place, exactly as the original code did.
        span_table: 1-D tensor indexed by character position; updated in
            place with 1 at the first character of a span and 2 for the rest.
    """
    for span in spans:
        # Looked up lazily so that chunks without spans never index the batch.
        encoding = test_batch[batch_id]

        # Token 0 is skipped as a start, and an end of 1 is pushed to 2
        # (presumably to step over a leading special token -- TODO confirm).
        if span[0] == 0:
            span[0] += 1
        if span[1] == 1:
            span[1] += 1

        # Move the start forward until it lands on a token that maps to
        # characters, or until the span collapses.
        while True:
            start = encoding.token_to_chars(span[0])
            if start is not None or span[0] >= span[1]:
                break
            span[0] += 1

        # Move the end backward under the same rule.
        while True:
            end = encoding.token_to_chars(span[1])
            if end is not None or span[0] >= span[1]:
                break
            span[1] -= 1

        # When the span is still non-empty here, both loops can only have
        # exited through their "is not None" branch, so start/end are valid.
        if span[0] < span[1]:
            # NOTE(review): the source text between `span[0]` and `512)` was
            # lost (everything from a `<` up to a `>` was stripped). The three
            # statements below are a reconstruction: the span start/end are
            # taken as the first character of the start/end tokens and the
            # debug print fires past 512 characters. Confirm `end[0]` versus
            # `end[1]` and the exact debug condition against the upstream file.
            de_start = start[0]
            de_end = end[0]
            if de_end > 512:
                print(de_start, de_end)
            span_table[de_start:de_end] = 2  # inside
            span_table[de_start] = 1  # begin


def test_predict(test_loader, device, model, min_label=1, max_label=3):
    """Run the model over ``test_loader`` and decode character-level spans.

    The model is switched to eval mode for the duration of the call and put
    back into train mode before returning. Inference runs under
    ``torch.no_grad()`` so no autograd graph is kept.

    Args:
        test_loader: Iterable of batches providing the keys ``batch_text``,
            ``input_ids``, ``token_type_ids``, ``attention_mask``,
            ``crf_mask`` and ``overflow_to_sample_mapping``; each batch must
            also support ``batch[i].token_to_chars(...)``.
        device: Device the input tensors are moved to.
        model: Callable accepting ``input_ids``, ``token_type_ids``,
            ``attention_mask``, ``labels`` and ``crf_mask``; must expose
            ``model.crf.decode`` when ``args.use_crf`` is set.
        min_label: Forwarded to ``cal_agreement_span`` as ``min_agree``.
        max_label: Forwarded to ``cal_agreement_span`` as ``max_agree``.

    Returns:
        A list with one dict per sample, holding ``"text_a"`` (the sample
        text) and ``"predict_span_table"`` (a tensor with one entry per
        character: 0 outside a span, 1 at a span start, 2 inside a span).
    """
    model.eval()
    result = []
    with torch.no_grad():
        for test_batch in tqdm(test_loader):
            input_ids = test_batch['input_ids'].to(device)
            token_type_ids = test_batch['token_type_ids'].to(device)
            attention_mask = test_batch['attention_mask'].to(device)
            crf_mask = test_batch["crf_mask"].to(device)
            sample_mapping = test_batch["overflow_to_sample_mapping"]

            output = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=None,
                crf_mask=crf_mask,
            )
            if args.use_crf:
                prediction = model.crf.decode(output[0], crf_mask)
            else:
                prediction = torch.max(output[0], -1).indices

            # Build one result entry per sample; consecutive chunks that map
            # to the same sample keep writing into the same span table.
            sample_id = -1
            decode_span_table = None
            for batch_id, mapped_id in enumerate(sample_mapping):
                mapped_id = int(mapped_id)  # plain int for list indexing
                if mapped_id != sample_id:
                    sample_id = mapped_id
                    text = test_batch['batch_text'][sample_id]
                    decode_span_table = torch.zeros(len(text))
                    # The entry is appended once per sample; the table object
                    # it references keeps being filled by later chunks.
                    result.append({
                        "text_a": text,
                        "predict_span_table": decode_span_table,
                    })

                # `cal_agreement_span` is invoked unbound, with None as self.
                spans = items_dataset.cal_agreement_span(
                    None,
                    agreement_table=prediction[batch_id],
                    min_agree=min_label,
                    max_agree=max_label,
                )
                _mark_spans(test_batch, batch_id, spans, decode_span_table)
    model.train()
    return result


def add_sentence_table(result, pattern=":;。,,?!~!: ", threshold_num=5, threshold_rate=0.5):
    """Add a sentence-level table derived from each sample's span table.

    Each text is cut at every character contained in ``pattern``. A segment
    is selected when the number of its characters with a positive entry in
    ``"predict_span_table"`` exceeds ``threshold_num`` or exceeds
    ``threshold_rate`` times the segment length. Selected segments are
    written as 2 into ``sample["predict_sentence_table"]``, with 1 at the
    first character of a run of consecutive selected segments; the boundary
    character between two consecutive selected segments is also set to 2.

    Args:
        result: List of sample dicts with ``"text_a"`` and
            ``"predict_span_table"``; modified in place.
        pattern: Characters treated as segment boundaries.
        threshold_num: Absolute count of marked characters to exceed.
        threshold_rate: Fraction of the segment length to exceed.
    """
    for sample in result:
        text = sample['text_a']
        span_table = sample["predict_span_table"]

        boundary_list = [i for i, char in enumerate(text) if char in pattern]
        # Sentinel past the end so the trailing segment is processed as well.
        boundary_list.append(len(text) + 1)

        sentence_table = torch.zeros(len(span_table))
        sample["predict_sentence_table"] = sentence_table

        start = 0
        first_sentence = True
        for end in boundary_list:
            segment = span_table[start:end]
            predict_num = int((segment > 0).sum())
            sentence_num = len(segment)
            if predict_num > threshold_num or predict_num > sentence_num * threshold_rate:
                if first_sentence:
                    sentence_table[start:end] = 2
                    sentence_table[start] = 1
                    first_sentence = False
                else:
                    # Continue the running span, covering the boundary
                    # character that precedes this segment.
                    sentence_table[start - 1:end] = 2
            else:
                first_sentence = True
            start = end + 1


def add_doc_id(result, test_data):
    """Attach a ``"docid"`` to every sample in ``result``.

    The id is looked up by exact match on ``"text_a"``; when several entries
    of ``test_data`` share the same text, the last one wins.

    Args:
        result: List of sample dicts with ``"text_a"``; modified in place.
        test_data: Iterable of dicts with ``"text_a"`` and ``"docid"``.

    Raises:
        KeyError: If a sample's text does not occur in ``test_data``.
    """
    text_to_id = {sample["text_a"]: sample["docid"] for sample in test_data}
    for sample in result:
        sample["docid"] = text_to_id[sample["text_a"]]