File size: 8,654 Bytes
9fdc3cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea140fb
9fdc3cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from args import args, config
from items_dataset import items_dataset
from torch.utils.data import DataLoader
from models import Model_Crf, Model_Softmax
from transformers import AutoTokenizer
from tqdm import tqdm
import prediction
import torch
import math

directory = args.SAVE_MODEL_PATH
# Place both models on the first CUDA device when available, else on the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# CRF-head model: used by predict_single for the first prediction pass.
model_name = "roberta_CRF.pt"
model_crf = Model_Crf(config).to(device)
model_crf.load_state_dict(
    state_dict=torch.load(directory + model_name, map_location=device)
)

# Softmax-head model: used by rank_spans to score candidate spans.
# NOTE: `model_name` is deliberately rebound; it ends as the last loaded file.
model_name = "roberta_softmax.pt"
model_roberta = Model_Softmax(config).to(device)
model_roberta.load_state_dict(
    state_dict=torch.load(directory + model_name, map_location=device)
)


def prepare_span_data(dataset):
    """Attach candidate spans to every sample and rename its text key, in place.

    For each sample, ``items_dataset.cal_agreement_span`` is applied to
    ``sample["predict_sentence_table"]`` (``min_agree=1``, ``max_agree=2``)
    and the result is stored under ``"span_labels"``. The ``"text_a"`` entry
    is then moved to ``"original_text"``. Returns ``None``.

    Raises:
        KeyError: if a sample lacks ``"predict_sentence_table"`` or ``"text_a"``.
    """
    for record in dataset:
        record["span_labels"] = items_dataset.cal_agreement_span(
            # Passed None as ``self``; presumably the method reads no
            # instance state — confirm in items_dataset.
            None,
            agreement_table=record["predict_sentence_table"],
            min_agree=1,
            max_agree=2,
        )
        record["original_text"] = record.pop("text_a")


def rank_spans(test_loader, device, model, reverse=True):
    """Score each candidate span as e ** (average per-token log-probability).

    For every batch, the model's first output is softmax-normalised over the
    last dimension. Every span that ``items_dataset.cal_agreement_span`` finds
    in the batch's ``labels`` is scored from those probabilities: column 1 is
    used for the span's first token ("Begin") and column 2 for each following
    token ("Inside"). The span is then decoded back to text through the
    batch's ``token_to_chars`` offsets.

    A text may span several rows, linked by ``overflow_to_sample_mapping``.
    A span that starts at position 0 of a follow-up row of the same text is
    treated as the continuation of that text's previous span. The two texts
    are concatenated and the score is averaged over the tokens of both parts.

    Args:
        test_loader: iterable of batches.
            - Keys exposed: "batch_text", "input_ids", "token_type_ids",
              "attention_mask", "labels", "crf_mask" and
              "overflow_to_sample_mapping".
            - Must also support ``batch[row].token_to_chars(...)``.
            - Presumably a transformers ``BatchEncoding``; confirm against
              ``items_dataset.collate_fn``.
        device: device the input tensors are moved to; must match ``model``.
        model: network called with those tensors and ``labels=None``.
        reverse: unused; kept for backward compatibility (sorting spans by
            score is currently disabled).

    Returns:
        One dict per text, in loader order:
        ``{"original_text": str, "span_ranked": [[span_text, score], ...]}``.
        Spans are listed in order of appearance.
    """
    model.eval()
    result = []

    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for test_batch in tqdm(test_loader):
            batch_text = test_batch["batch_text"]
            labels = test_batch["labels"]
            sample_mapping = test_batch["overflow_to_sample_mapping"]
            output = model(
                input_ids=test_batch["input_ids"].to(device),
                token_type_ids=test_batch["token_type_ids"].to(device),
                attention_mask=test_batch["attention_mask"].to(device),
                labels=None,
                crf_mask=test_batch["crf_mask"].to(device),
            )
            probs = torch.nn.functional.softmax(output[0], dim=-1)

            # Rows of one text are expected to be contiguous. A new result
            # entry starts whenever the mapped text index changes.
            sample_id = 0
            sample_result = {
                "original_text": batch_text[sample_id],
                "span_ranked": [],
            }
            # Running log-probability sum and token count of the most recent
            # span. This lets a span continued in the next row be averaged
            # over all of its tokens.
            span_log_sum = 0.0
            span_token_count = 0

            for batch_id in range(len(sample_mapping)):
                change_sample = bool(sample_id != sample_mapping[batch_id])
                if change_sample:
                    sample_id = sample_mapping[batch_id]
                    result.append(sample_result)
                    sample_result = {
                        "original_text": batch_text[sample_id],
                        "span_ranked": [],
                    }

                encoded_spans = items_dataset.cal_agreement_span(
                    None, agreement_table=labels[batch_id], min_agree=1, max_agree=2
                )
                for encoded_span in encoded_spans:
                    span_start, span_end = encoded_span[0], encoded_span[1]
                    # A span at the very start of a follow-up row continues
                    # the previous span. This needs an earlier span for this
                    # text to attach to; otherwise it is scored as a new span.
                    continues_previous = (
                        not change_sample
                        and span_start == 0
                        and batch_id != 0
                        and bool(sample_result["span_ranked"])
                    )

                    log_probs = torch.log(probs[batch_id][span_start:span_end])
                    chunk_log_sum = log_probs[0][1] + log_probs[1:, 2].sum()
                    chunk_len = span_end - span_start
                    if continues_previous:
                        span_log_sum = span_log_sum + chunk_log_sum
                        span_token_count = span_token_count + chunk_len
                    else:
                        span_log_sum = chunk_log_sum
                        span_token_count = chunk_len
                    score = math.e ** float(span_log_sum / span_token_count)

                    # Token positions are shifted by 1 before token_to_chars,
                    # presumably to skip a leading special token.
                    # NOTE(review): the end offset is "start char of the last
                    # token + 1". That assumes one character per token (true
                    # for CJK text); confirm for other inputs.
                    encoding = test_batch[batch_id]
                    decode_start = encoding.token_to_chars(span_start + 1)[0]
                    decode_end = encoding.token_to_chars(span_end)[0] + 1
                    span_text = batch_text[sample_mapping[batch_id]][
                        decode_start:decode_end
                    ]

                    if continues_previous:
                        previous_text = sample_result["span_ranked"].pop(-1)[0]
                        sample_result["span_ranked"].append(
                            [previous_text + span_text, score]
                        )
                    else:
                        sample_result["span_ranked"].append([span_text, score])
            result.append(sample_result)

    return result


def _split_by_spans(text, scored_spans):
    """Cut ``text`` into consecutive ``[piece, score]`` pairs.

    Each ``[span_text, score]`` in ``scored_spans`` is searched for in
    ``text``, starting where the previous span ended, and emitted with its
    score. Text between spans, or after the last one, is emitted with score 0.

    Raises:
        ValueError: if a span's text is not found after the previous span.
    """
    pieces = []
    cursor = 0
    for span_text, score in scored_spans:
        start = text.index(span_text, cursor)
        if start != cursor:
            pieces.append([text[cursor:start], 0])
        pieces.append([span_text, score])
        cursor = start + len(span_text)
    if cursor < len(text):
        pieces.append([text[cursor:], 0])
    return pieces


def predict_single(text):
    """Run the two-stage span pipeline on one text and segment it by score.

    Stages:
        1. ``model_crf`` is run through ``prediction.test_predict`` and
           ``prediction.add_sentence_table``. ``prepare_span_data`` then turns
           the resulting table into candidate spans.
        2. ``model_roberta`` scores those spans via ``rank_spans``.

    Returns:
        A list of ``[piece, score]`` pairs that together cover ``text`` in
        order. Scored spans carry their score; the text between them carries 0.
    """
    # The same tokenizer configuration serves both stages, so load it once.
    tokenizer = AutoTokenizer.from_pretrained(
        args.pre_model_name, add_prefix_space=True
    )

    # Stage 1: candidate spans from the CRF model. shuffle=False keeps
    # inference deterministic and leaves the global torch RNG untouched.
    input_dict = [{"span_labels": [], "original_text": text}]
    prediction_dataset = items_dataset(tokenizer, input_dict, args.label_dict)
    prediction_loader = DataLoader(
        prediction_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=prediction_dataset.collate_fn,
    )
    predict_data = prediction.test_predict(prediction_loader, device, model_crf)
    prediction.add_sentence_table(predict_data)

    # Stage 2: score the candidate spans with the softmax model.
    prepare_span_data(predict_data)
    span_dataset = items_dataset(tokenizer, predict_data, args.label_dict)
    span_loader = DataLoader(
        span_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=span_dataset.collate_fn,
    )
    span_ranked = rank_spans(span_loader, device, model_roberta)

    sample = span_ranked[0]
    return _split_by_spans(sample["original_text"], sample["span_ranked"])


if __name__ == "__main__":
    # Demo: segment a sample article and print its [piece, score] pairs.
    demo_text = """貓咪犯錯後,以下5種懲罰方法很有效,飼主可以試試樂享網 2021-03-06 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 貓咪雖然高冷,但也是會犯錯的,那貓咪犯錯後,怎麼懲罰它才最有效呢?今天就來說一些懲罰貓咪最有效的5個方法!1、把痛感形成條件反射 這裡說的是「痛感」,而不是「暴打」。在貓咪犯錯後,寵主不需要打它,可以彈鼻頭或者是輕拍它的頭頂,給它造成痛感,這樣讓貓咪有一些畏懼心理,知道你在懲罰它。這樣時間長了,貓咪就會形成條件反射,以後就會少犯錯了。  2、大聲呵斥比起打貓,大聲呵斥貓咪會更加有效。因為貓咪對聲音很敏感,它能從主人的語氣中判斷主人的情緒,當大聲呵斥它的時候,它往往會楞一下,這時你繼續大聲呵斥它,那它就會明白你不允許它做這件事,這樣犯錯地方幾率就會減少了。  3、限制自由限制自由說白了,就是把貓咪關進籠子裡。因為貓咪都是很愛外出玩耍,當它犯錯咯,主人可以把它關進籠子裡,不搭理它,讓它自己反思。但要注意,這個方法不能經常用,而且不能把貓咪關進籠子太久。  4、利用水都知道貓咪很怕水的,所以當貓咪犯錯後,寵主也可以利用水來懲罰貓咪,這也是很效果的方法。寵主可以給貓咪臉上或是頭頂噴一些水,從而讓貓知道這個行為是錯誤的,以後就不會再犯了。  5、冷淡雖然貓咪不是很粘主人,但它還是很愛主人的,所以在貓咪犯錯後,寵主也可以採取冷淡的方法來懲罰貓。對貓咪採取不理睬、不靠近、不擁抱等策略,這樣貓咪就會知道自己錯了。當然懲罰的時間不要太長,不然貓咪就會以為你不愛它了。"""
    segments = predict_single(demo_text)
    print(segments)