import torch
from args import args, config
from tqdm import tqdm
from items_dataset import items_dataset

def test_predict(test_loader, device, model, min_label=1, max_label=3):
  """Run inference over test_loader and decode the predicted token spans
  back to character offsets for every original (pre-chunking) sample."""
  model.eval()
  result = []

  for test_batch in tqdm(test_loader):
    input_ids = test_batch['input_ids'].to(device)
    token_type_ids = test_batch['token_type_ids'].to(device)
    attention_mask = test_batch['attention_mask'].to(device)
    crf_mask = test_batch["crf_mask"].to(device)
    # long texts overflow into several chunks at tokenization time;
    # this maps each chunk back to the sample it came from
    sample_mapping = test_batch["overflow_to_sample_mapping"]
    output = model(input_ids=input_ids, token_type_ids=token_type_ids,
                   attention_mask=attention_mask, labels=None, crf_mask=crf_mask)
    if args.use_crf:
      prediction = model.crf.decode(output[0], crf_mask)
    else:
      prediction = torch.max(output[0], -1).indices

    # assemble one result entry per original sample
    sample_id = -1
    for batch_id in range(len(sample_mapping)):
      change_sample = sample_id != sample_mapping[batch_id]
      if change_sample:
        sample_id = sample_mapping[batch_id]
        sample_result = {"text_a": test_batch['batch_text'][sample_id]}
        # one slot per character: 0 = outside, 1 = span begin, 2 = inside
        decode_span_table = torch.zeros(len(test_batch['batch_text'][sample_id]))

      spans = items_dataset.cal_agreement_span(None, agreement_table=prediction[batch_id],
                                               min_agree=min_label, max_agree=max_label)
      # translate each predicted token span into character offsets
      for span in spans:
        if span[0] == 0: span[0] += 1  # skip the leading special token ([CLS])
        if span[1] == 1: span[1] += 1  # keep the span non-empty after the shift

        # advance the start past tokens with no character mapping
        # (special tokens return None from token_to_chars)
        while True:
          start = test_batch[batch_id].token_to_chars(span[0])
          if start is not None or span[0] >= span[1]:
            break
          span[0] += 1

        # likewise pull the end back onto a mappable token
        while True:
          end = test_batch[batch_id].token_to_chars(span[1])
          if end is not None or span[0] >= span[1]:
            break
          span[1] -= 1

        if span[0] < span[1]:
          de_start = test_batch[batch_id].token_to_chars(span[0])[0]
          de_end = test_batch[batch_id].token_to_chars(span[1] - 1)[0]
          decode_span_table[de_start:de_end] = 2  # inside
          decode_span_table[de_start] = 1         # begin
      if change_sample:
        # appended by reference: later chunks of the same sample keep
        # filling this tensor after it is already in `result`
        sample_result["predict_span_table"] = decode_span_table
        result.append(sample_result)
  model.train()
  return result

def add_sentence_table(result):
  """Split each text on punctuation and mark every segment that overlaps a
  predicted span (1 = segment begin, 2 = inside, 0 = outside)."""
  pattern = ":;。,?!~!: "  # full- and half-width sentence delimiters
  for sample in result:
    # character positions of every delimiter, plus one past the end
    boundary_list = []
    for i, char in enumerate(sample['text_a']):
      if char in pattern:
        boundary_list.append(i)
    boundary_list.append(len(sample['text_a']) + 1)

    start = 0
    pre_states = False  # did the previous segment contain a predicted span?
    sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
    for boundary in boundary_list:
      end = boundary
      if sum(sample["predict_span_table"][start:end]) > 0:
        if pre_states:
          # extend the previous marked segment across the delimiter
          sample["predict_sentence_table"][start-1:end] = 2
        else:
          sample["predict_sentence_table"][start:end] = 2
          sample["predict_sentence_table"][start] = 1
        pre_states = True
      else:
        pre_states = False
      start = end + 1
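
# --- usage sketch (an assumption, not part of the original file) ---
# test_predict expects a DataLoader whose batches keep the tokenizer's
# overflow_to_sample_mapping and the raw 'batch_text'; building the loader
# and the model is project-specific, so it is only indicated in comments.
if __name__ == "__main__":
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  # test_loader = ...  # DataLoader over items_dataset (return_overflowing_tokens=True)
  # model = ...        # trained tagger, with model.crf when args.use_crf is set
  # result = test_predict(test_loader, device, model)
  # add_sentence_table(result)

  # self-contained demo of the char-level table convention used above:
  # 1 = first character of a span, 2 = inside, 0 = outside
  text = "價格偏高,但品質很好。"
  table = torch.zeros(len(text))
  table[0] = 1    # begin of a pretend predicted span "價格偏高"
  table[1:4] = 2  # inside of that span
  print([(char, int(tag)) for char, tag in zip(text, table)])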