add outline and BERT inference
- .gitattributes +0 -34
- BERT_inference.py +0 -0
- __pycache__/abstract.cpython-39.pyc +0 -0
- __pycache__/classification.cpython-39.pyc +0 -0
- __pycache__/inference.cpython-39.pyc +0 -0
- __pycache__/outline.cpython-39.pyc +0 -0
- __pycache__/util.cpython-39.pyc +0 -0
- abstruct.py → abstract.py +0 -0
- bert_model.pkl +0 -3
- classification.py +10 -15
- inference.py +75 -0
- outline.py +31 -0
- run.py +20 -7
- util.py +15 -3
.gitattributes
DELETED
@@ -1,34 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
BERT_inference.py
ADDED
File without changes
__pycache__/abstract.cpython-39.pyc
ADDED
Binary file (2.19 kB)

__pycache__/classification.cpython-39.pyc
ADDED
Binary file (2.67 kB)

__pycache__/inference.cpython-39.pyc
ADDED
Binary file (3.02 kB)

__pycache__/outline.cpython-39.pyc
ADDED
Binary file (834 Bytes)

__pycache__/util.cpython-39.pyc
ADDED
Binary file (2.77 kB)
abstruct.py → abstract.py
RENAMED
File without changes
bert_model.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dc61799e024f5d62f883b9e04886c749468378c19ed0185311a5ce4031ae8a5d
-size 409205202
classification.py
CHANGED
@@ -15,18 +15,17 @@ def classify_by_topic(articles, central_topics):
     tokenizer = AutoTokenizer.from_pretrained(
         "distilbert-base-multilingual-cased")
 
-    # convert a single sentence into a vector
     def sentence_to_vector(sentence, context):
-
+
         sentence = context[0]+context[1]+sentence*4+context[2]+context[3]
         tokens = tokenizer.encode_plus(
             sentence, add_special_tokens=True, return_tensors="pt")
-
+
         outputs = model(**tokens)
         hidden_states = outputs.last_hidden_state
-
+
         vector = np.squeeze(torch.mean(
-            hidden_states, dim=1).detach().numpy())
+            hidden_states, dim=1).detach().numpy())
         return vector
 
     # get the context of a sentence
@@ -51,26 +50,23 @@ def classify_by_topic(articles, central_topics):
             nnext_sentence = sentences[index+2]
         return (pprev_sentence, prev_sentence, next_sentence, nnext_sentence)
 
-    # convert each article sentence and each central sentence into a vector
     doc_vectors = [sentence_to_vector(sentence, get_context(
         articles, i)) for i, sentence in enumerate(articles)]
     topic_vectors = [sentence_to_vector(sentence, get_context(
         central_topics, i)) for i, sentence in enumerate(central_topics)]
-    #
+    # compute the cosine similarity matrix
     cos_sim_matrix = cosine_similarity(doc_vectors, topic_vectors)
 
-    # print(cos_sim_matrix)
     return cos_sim_matrix
 
-    #
+    # group the articles by topic
     def group_by_topic(articles, central_topics, similarity_matrix):
         group = []
-        original_articles = articles.copy()
-        # replace the preprocessed article list with the original one
+        original_articles = articles.copy()
         for article, similarity in zip(original_articles, similarity_matrix):
-            max_similarity = max(similarity)
-            max_index = similarity.tolist().index(max_similarity)
-
+            max_similarity = max(similarity)
+            max_index = similarity.tolist().index(max_similarity)
+
             group.append((article, central_topics[max_index]))
 
         return group
@@ -79,5 +75,4 @@ def classify_by_topic(articles, central_topics):
    similarity_matrix = compute_similarity(articles, central_topics)
    groups = group_by_topic(articles, central_topics, similarity_matrix)
 
-    # return the grouped result
    return groups
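The change above is mostly comment and whitespace cleanup; the grouping logic itself is unchanged: each article sentence is assigned to the central topic with the highest cosine similarity. A minimal illustration of that step with a made-up 3×2 similarity matrix (the names and values below are placeholders, not repo data):

import numpy as np

articles = ["sentence 1", "sentence 2", "sentence 3"]
central_topics = ["topic A", "topic B"]
# rows: article sentences, columns: central topics (illustrative values)
cos_sim_matrix = np.array([[0.9, 0.2],
                           [0.1, 0.8],
                           [0.6, 0.5]])

# mirrors group_by_topic(): pair each sentence with its best-matching topic
groups = [(a, central_topics[int(np.argmax(row))])
          for a, row in zip(articles, cos_sim_matrix)]
print(groups)  # [('sentence 1', 'topic A'), ('sentence 2', 'topic B'), ('sentence 3', 'topic A')]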
inference.py
ADDED
@@ -0,0 +1,75 @@
+import os
+import numpy as np
+import transformers
+import torch
+import torch.nn as nn
+from torch import cuda
+from transformers import BertTokenizer
+
+
+def encoder(max_len,text):
+
+    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
+    tokenizer = tokenizer(
+        text,
+        padding = True,
+        truncation = True,
+        max_length = max_len,
+        return_tensors='pt'
+    )
+    input_ids = tokenizer['input_ids']
+    token_type_ids = tokenizer['token_type_ids']
+    attention_mask = tokenizer['attention_mask']
+    return input_ids,token_type_ids,attention_mask
+
+
+def predict(model,device,text):
+    model.to(device)
+    model.eval()
+    with torch.no_grad():
+        input_ids,token_type_ids,attention_mask = encoder(512,text)
+        input_ids,token_type_ids,attention_mask = input_ids.to(device),token_type_ids.to(device),attention_mask.to(device)
+        out_put = model(input_ids,token_type_ids,attention_mask)
+        # pre_numpy = out_put.cpu().numpy().tolist()
+        probs = torch.nn.functional.softmax(out_put).detach().cpu().numpy().tolist()
+        # print(probs)
+        return probs[0][1]
+
+
+class BertClassificationModel(nn.Module):
+    def __init__(self):
+        super(BertClassificationModel, self).__init__()
+        pretrained_weights="bert-base-chinese"
+        self.bert = transformers.BertModel.from_pretrained(pretrained_weights)
+        for param in self.bert.parameters():
+            param.requires_grad = True
+        self.dense = nn.Linear(768, 3)
+
+    def forward(self, input_ids,token_type_ids,attention_mask):
+        bert_output = self.bert(input_ids=input_ids,token_type_ids=token_type_ids, attention_mask=attention_mask)
+        bert_cls_hidden_state = bert_output[1]
+        linear_output = self.dense(bert_cls_hidden_state)
+        return linear_output
+
+def inference_matrix(topics):
+    device = torch.device('cuda' if cuda.is_available() else 'cpu')
+    load_path = "TSA/bert_model.pkl"
+    model = torch.load(load_path,map_location=torch.device(device))
+    matrix = np.zeros([len(topics),len(topics)],dtype=float)
+    for i,i_text in enumerate(topics):
+        for j,j_text in enumerate(topics):
+            if(i == j):
+                matrix[i][j] = 1
+            else:
+                test = i_text+" 是否包含 "+j_text
+                outputs = predict(model,device,test)
+                # outputs = model(ids, mask,token_type_ids)
+                # print(outputs)
+                matrix[i][j] = outputs
+
+    return matrix
+if __name__ == "__main__":
+
+    print("yes")
+    topics = ['在本次报告中我将介绍分布式并行加速算法模型架构内存和计算优化以及集群架构等关键技术', '在现代机器学习任务中大模型训练已成为解决复杂问题的重要手段', '首先分布式并行加速策略包括数据并行模型并行流水线并行和张量并行等四种方式', '选择合适的集群架构是实现大模型的分布式训练的关键', '这些策略帮助我们将训练数据和模型分布到多个设备上以加速大模型训练过程']
+    print(inference_matrix(topics))
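One detail worth noting in the new predict() helper: torch.nn.functional.softmax is called without a dim argument, which current PyTorch releases flag with a deprecation warning (for a 2-D logits tensor it falls back to the last dimension, so the result matches dim=1). A sketch of the equivalent call with the dimension made explicit; predict_explicit is a hypothetical name, and it reuses the committed encoder() and an already-loaded BertClassificationModel:

import torch
import torch.nn.functional as F

def predict_explicit(model, device, text):
    model.to(device)
    model.eval()
    with torch.no_grad():
        input_ids, token_type_ids, attention_mask = encoder(512, text)
        input_ids = input_ids.to(device)
        token_type_ids = token_type_ids.to(device)
        attention_mask = attention_mask.to(device)
        logits = model(input_ids, token_type_ids, attention_mask)  # shape (1, 3)
        probs = F.softmax(logits, dim=1)                           # softmax over the 3 classes
        return probs[0][1].item()                                  # probability of class 1, as in predict()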
outline.py
ADDED
@@ -0,0 +1,31 @@
+import numpy as np
+from scipy.cluster.hierarchy import linkage, fcluster, dendrogram
+import matplotlib.pyplot as plt
+
+def passage_outline(matrix,sentences):
+    # matrix = np.array([[1.0, 0.8, 0.2, 0.1],
+    #                    [0.8, 1.0, 0.3, 0.2],
+    #                    [0.2, 0.3, 1.0, 0.9],
+    #                    [0.1, 0.2, 0.9, 1.0]])
+    # sentences = ["主题句子1", "主题句子2", "主题句子3", "主题句子4"]
+
+    Z = linkage(matrix, method="average")
+
+    labels = fcluster(Z, t=0.5, criterion="distance")
+
+
+    # build the passage structure from the cluster labels and topic sentences
+    structure = {}
+    for label, sentence in zip(labels, sentences):
+        if label not in structure:
+            structure[label] = []
+        structure[label].append(sentence)
+    outline = ""
+    outline_list = []
+    for key in sorted(structure.keys()):
+        outline_list.append(f"主题{key}:")
+        outline = outline+f"主题{key}:\n"
+        for sentence in structure[key]:
+            outline_list.append(sentence)
+            outline = outline+f"- {sentence}\n"
+    return outline,outline_list
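Note that linkage receives the square similarity matrix directly, so SciPy clusters each row's similarity profile as an observation vector rather than interpreting the input as a condensed distance matrix, and fcluster then cuts the dendrogram at t=0.5. A small sketch of how the new function could be driven, reusing the example values from its own comments (the import assumes outline.py is on the Python path):

import numpy as np
from outline import passage_outline

matrix = np.array([[1.0, 0.8, 0.2, 0.1],
                   [0.8, 1.0, 0.3, 0.2],
                   [0.2, 0.3, 1.0, 0.9],
                   [0.1, 0.2, 0.9, 1.0]])
sentences = ["主题句子1", "主题句子2", "主题句子3", "主题句子4"]

text_outline, outline_list = passage_outline(matrix, sentences)
print(text_outline)  # with these values, sentences 1-2 and 3-4 end up under separate headings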
run.py
CHANGED
@@ -1,7 +1,9 @@
 import util
-import abstruct
+import abstract
 import classification
-
+import inference
+import outline
+from inference import BertClassificationModel
 # input:file/text,topic_num,max_length,output_choice
 # output:file/text/topic_sentence
 
@@ -29,15 +31,26 @@ article = util.seg(text)
 print(article)
 
 sentences = [util.clean_text(sentence) for sentence in article]
-
+
+central_sentences = abstract.abstruct_main(sentences, topic_num)
 print(central_sentences)
+
 groups = classification.classify_by_topic(article, central_sentences)
 print(groups)
 
 groups = util.article_to_group(groups, central_sentences)
-
+
+title_dict,title = util.generation(groups, max_length)
 # ans:
-# {(main_sentence,
-for i in
+# {Ai_abstruct:(main_sentence,paragraph)}
+for i in title_dict.items():
 print(i)
-
+
+matrix = inference.inference_matrix(title)
+print(matrix)
+
+text_outline,outline_list = outline.passage_outline(matrix,title)
+print(text_outline)
+
+output = util.formate_text(title_dict,outline_list)
+print (output)
util.py
CHANGED
@@ -3,10 +3,12 @@ import jieba
 import re
 import requests
 import backoff
+import time
 
 
 @backoff.on_exception(backoff.expo, requests.exceptions.RequestException)
 def post_url(url, headers, payload):
+    time.sleep(0.3)
    response = requests.request("POST", url, headers=headers, data=payload)
    return response
 
@@ -55,7 +57,7 @@ def generation(para, max_length):
 
    url = "https://aip.baidubce.com/rpc/2.0/nlp/v1/news_summary?charset=UTF-8&access_token=" + get_access_token()
    topic = {}
-
+    Ai_abstract = []
    for i, (j, k) in enumerate(para.items()):
        input_text = k
        # print(k)
@@ -71,5 +73,15 @@ def generation(para, max_length):
        response = post_url(url, headers, payload)
        text_dict = json.loads(response.text)
        # print(text_dict)
-        topic[
-
+        topic[text_dict['summary']] = (j, k)
+        Ai_abstract.append(text_dict['summary'])
+    return topic,Ai_abstract
+def formate_text(title_dict,outline_list):
+    formated = []
+    for each in outline_list:
+        if(each not in title_dict.keys()):
+            formated.append(f"# {each}")
+        if(each in title_dict.keys()):
+            formated.append(f"## {each}")
+            formated.append(title_dict[each][1])
+    return formated
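The new formate_text() walks outline_list and emits a "#" heading for cluster labels and a "##" heading plus the original paragraph for entries that are keys of title_dict (the summaries returned by generation()). A toy illustration of the expected shapes; all values below are placeholders, not output from the Baidu API, and the import assumes util.py and its dependencies are installed:

from util import formate_text

title_dict = {
    "summary A": ("central sentence A", "paragraph text A"),
    "summary B": ("central sentence B", "paragraph text B"),
}
outline_list = ["主题1:", "summary A", "主题2:", "summary B"]

print(formate_text(title_dict, outline_list))
# ['# 主题1:', '## summary A', 'paragraph text A',
#  '# 主题2:', '## summary B', 'paragraph text B']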