scutcyr committed on
Commit
7646d66
1 Parent(s): c24230a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import subprocess
import sys

# Install the runtime dependencies (transformers + SentencePiece, torch and
# streamlit-chat) before they are imported below. pip is invoked through the
# running interpreter so the packages land in the environment that executes
# this script. Failures are not fatal here (best effort, as before); a missing
# package surfaces as an ImportError on the imports that follow.
for _packages in (
    ('transformers', 'SentencePiece'),
    ('torch',),
    ('streamlit-chat',),
):
    subprocess.run([sys.executable, '-m', 'pip', 'install', *_packages])

from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer
import torch

import streamlit as st
from streamlit_chat import message

# Download the model
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1")
# Inference runs on the CPU here; the original note suggests switching the
# (Colab) runtime to a GPU for faster inference.
device = torch.device('cpu')
model.to(device)
print('Model Load done!')
22
+
23
def preprocess(text):
    """Escape newlines and tabs in *text* before it is sent to the model.

    Each newline character becomes the two-character sequence ``\\n`` and
    each tab character becomes ``\\t``; all other characters are unchanged.
    """
    text = text.replace("\n", "\\n").replace("\t", "\\t")
    return text
26
+
27
def postprocess(text):
    """Turn escaped ``\\n`` / ``\\t`` sequences in *text* back into real
    newline and tab characters (the inverse of the escaping done by
    ``preprocess``)."""
    return text.replace("\\n", "\n").replace("\\t", "\t")
29
+
30
def answer(history, sample=True, top_p=1, temperature=0.7):
    """Generate the bot reply for a list of utterances.

    Args:
        history: list of utterance strings. Each one is escaped with
            ``preprocess`` and the whole list is tokenized as one batch
            (truncated/padded to at most 768 tokens).
        sample: whether to sample. For generation tasks this can be True.
            When False, ``generate`` is called with ``num_beams=1`` and
            ``length_penalty=0.6`` instead of the sampling arguments.
        top_p: value between 0 and 1 passed to ``generate`` when sampling;
            per the original note, larger values give more diverse output.
        temperature: sampling temperature passed to ``generate``.

    Returns:
        The decoded first generated sequence, with escaped newlines/tabs
        restored by ``postprocess``. Generation is capped at
        ``max_new_tokens=512``.
    """
    # Build the escaped batch directly; the previous index assignment into an
    # empty list (using an undefined name) raised before the model was called.
    preprocess_history = [preprocess(utterance) for utterance in history]

    encoding = tokenizer(text=preprocess_history, truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
    if not sample:
        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
    else:
        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    # NOTE(review): only the first decoded sequence is used, i.e. the output
    # for the first entry of `history` -- confirm this is intended when the
    # history holds more than one utterance.
    reply = postprocess(out_text[0])
    print('小元: ' + reply)
    return reply
50
+
51
# Page metadata and header shown at the top of the Streamlit app.
st.set_page_config(
    page_title="Chinese ChatBot - Demo",
    page_icon=":robot:"
)

st.header("Chinese ChatBot - Demo")
st.markdown("[Github](https://github.com/scutcyr)")

# Make sure both transcript lists exist in the session state:
# 'generated' holds the bot replies and 'past' holds the user inputs.
for _key in ('generated', 'past'):
    if _key not in st.session_state:
        st.session_state[_key] = []
64
+
65
def query(history):
    """Generate a reply through the tokenizer's ``dialogue_encode`` path.

    Encodes ``history`` with ``tokenizer.dialogue_encode``, samples 20
    candidate sequences of at most 64 tokens with ``model.generate`` and
    returns the first item produced by ``select_response``.

    NOTE(review): this function is not called anywhere in the visible source,
    and ``select_response`` is neither defined nor imported there. It also
    relies on ``tokenizer.dialogue_encode`` and on ``model.generate``
    returning an ``(ids, scores)`` pair -- confirm the tokenizer/model loaded
    by this app actually provide that interface before using it.
    """
    # Shared by the generate() call and select_response() below.
    max_dec_len = 64
    num_return_sequences = 20

    inputs = tokenizer.dialogue_encode(
        history, add_start_token_as_response=True, return_tensors=True, is_split_into_words=False
    )
    inputs["input_ids"] = inputs["input_ids"].astype("int64")
    ids, scores = model.generate(
        input_ids=inputs["input_ids"],
        token_type_ids=inputs["token_type_ids"],
        position_ids=inputs["position_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_dec_len,
        min_length=1,
        decode_strategy="sampling",
        temperature=1.0,
        top_k=5,
        top_p=1.0,
        num_beams=0,
        length_penalty=1.0,
        early_stopping=False,
        num_return_sequences=num_return_sequences,
    )
    bot_response = select_response(
        ids, scores, tokenizer, max_dec_len, num_return_sequences, keep_space=False
    )[0]
    return bot_response
92
+
93
def get_text():
    """Render the user text box and return its current content.

    The box is labelled "用户: ", is pre-filled with "你好!" and is
    registered under the widget key "input".
    """
    input_text = st.text_input("用户: ", "你好!", key="input")
    return input_text
96
+
97
# `history` is rebuilt from scratch each time this script runs, so the model
# only ever receives the current user input; the full transcript lives in
# st.session_state ('past' / 'generated').
history = []
user_input = get_text()
history.append(user_input)

if user_input:
    output = answer(history)
    st.session_state.past.append(user_input)
    st.session_state.generated.append(output)
    history.append(output)

if st.session_state['generated']:
    # Render newest exchange first: bot reply, then the user message that
    # prompted it. Keys must be unique per rendered message widget.
    for i in reversed(range(len(st.session_state['generated']))):
        message(st.session_state["generated"][i], key=str(i))
        message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')