Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -11,13 +11,7 @@ import torch
|
|
11 |
import streamlit as st
|
12 |
from streamlit_chat import message
|
13 |
|
14 |
-
|
15 |
-
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
|
16 |
-
model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1")
|
17 |
-
# 修改colab笔记本设置为gpu,推理更快
|
18 |
-
device = torch.device('cpu')
|
19 |
-
model.to(device)
|
20 |
-
print('Model Load done!')
|
21 |
|
22 |
def preprocess(text):
|
23 |
text = text.replace("\n", "\\n").replace("\t", "\\t")
|
@@ -56,6 +50,25 @@ st.set_page_config(
|
|
56 |
st.header("Chinese ChatBot - Demo")
|
57 |
st.markdown("[Github](https://github.com/scutcyr)")
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
if 'generated' not in st.session_state:
|
60 |
st.session_state['generated'] = []
|
61 |
|
@@ -67,16 +80,16 @@ def get_text():
|
|
67 |
input_text = st.text_input("用户: ","你好!", key="input")
|
68 |
return input_text
|
69 |
|
70 |
-
user_history = []
|
71 |
-
bot_history = []
|
72 |
user_input = get_text()
|
73 |
-
user_history.append(user_input)
|
74 |
|
75 |
if user_input:
|
76 |
-
output = answer(user_history,bot_history)
|
77 |
st.session_state.past.append(user_input)
|
|
|
78 |
st.session_state.generated.append(output)
|
79 |
-
bot_history.append(output)
|
80 |
|
81 |
if st.session_state['generated']:
|
82 |
|
@@ -86,4 +99,16 @@ if st.session_state['generated']:
|
|
86 |
for i in range(len(st.session_state['generated'])):
|
87 |
message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
|
88 |
message(st.session_state["generated"][i], key=str(i))
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
import streamlit as st
|
12 |
from streamlit_chat import message
|
13 |
|
14 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def preprocess(text):
|
17 |
text = text.replace("\n", "\\n").replace("\t", "\\t")
|
|
|
50 |
st.header("Chinese ChatBot - Demo")
|
51 |
st.markdown("[Github](https://github.com/scutcyr)")
|
52 |
|
53 |
+
|
54 |
+
@st.cache_resource
|
55 |
+
def load_model():
|
56 |
+
model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1")
|
57 |
+
# 修改colab笔记本设置为gpu,推理更快
|
58 |
+
device = torch.device('cpu')
|
59 |
+
model.to(device)
|
60 |
+
print('Model Load done!')
|
61 |
+
return model
|
62 |
+
|
63 |
+
@st.cache_resource
|
64 |
+
def load_tokenizer():
|
65 |
+
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
|
66 |
+
print('Tokenizer Load done!')
|
67 |
+
return tokenizer
|
68 |
+
|
69 |
+
model = load_model()
|
70 |
+
tokenizer = load_tokenizer()
|
71 |
+
|
72 |
if 'generated' not in st.session_state:
|
73 |
st.session_state['generated'] = []
|
74 |
|
|
|
80 |
input_text = st.text_input("用户: ","你好!", key="input")
|
81 |
return input_text
|
82 |
|
83 |
+
#user_history = []
|
84 |
+
#bot_history = []
|
85 |
user_input = get_text()
|
86 |
+
#user_history.append(user_input)
|
87 |
|
88 |
if user_input:
|
|
|
89 |
st.session_state.past.append(user_input)
|
90 |
+
output = answer(st.session_state['past'],st.session_state["generated"])
|
91 |
st.session_state.generated.append(output)
|
92 |
+
#bot_history.append(output)
|
93 |
|
94 |
if st.session_state['generated']:
|
95 |
|
|
|
99 |
for i in range(len(st.session_state['generated'])):
|
100 |
message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
|
101 |
message(st.session_state["generated"][i], key=str(i))
|
102 |
+
|
103 |
+
|
104 |
+
if st.button("清理模型缓存"):
|
105 |
+
# Clear values from *all* all in-memory and on-disk data caches:
|
106 |
+
# i.e. clear values from both square and cube
|
107 |
+
st.cache_resource.clear()
|
108 |
+
torch.cuda.empty_cache()
|
109 |
+
|
110 |
+
if st.button("清理对话缓存"):
|
111 |
+
# Clear values from *all* all in-memory and on-disk data caches:
|
112 |
+
# i.e. clear values from both square and cube
|
113 |
+
st.session_state['generated'] = []
|
114 |
+
st.session_state['past'] = []
|