scutcyr committed
Commit 4a4d251 (1 parent: 7556b80)

Update app.py

Files changed (1): app.py (+38, -13)
app.py CHANGED
@@ -11,13 +11,7 @@ import torch
 import streamlit as st
 from streamlit_chat import message
 
-# Download the model
-tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
-model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1")
-# Switch the Colab notebook runtime to GPU for faster inference
-device = torch.device('cpu')
-model.to(device)
-print('Model Load done!')
+
 
 def preprocess(text):
     text = text.replace("\n", "\\n").replace("\t", "\\t")
@@ -56,6 +50,25 @@ st.set_page_config(
 st.header("Chinese ChatBot - Demo")
 st.markdown("[Github](https://github.com/scutcyr)")
 
+
+@st.cache_resource
+def load_model():
+    model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1")
+    # Switch the Colab notebook runtime to GPU for faster inference
+    device = torch.device('cpu')
+    model.to(device)
+    print('Model Load done!')
+    return model
+
+@st.cache_resource
+def load_tokenizer():
+    tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
+    print('Tokenizer Load done!')
+    return tokenizer
+
+model = load_model()
+tokenizer = load_tokenizer()
+
 if 'generated' not in st.session_state:
     st.session_state['generated'] = []
 
@@ -67,16 +80,16 @@ def get_text():
     input_text = st.text_input("用户: ", "你好!", key="input")  # label "User:", default "Hello!"
     return input_text
 
-user_history = []
-bot_history = []
+#user_history = []
+#bot_history = []
 user_input = get_text()
-user_history.append(user_input)
+#user_history.append(user_input)
 
 if user_input:
-    output = answer(user_history, bot_history)
     st.session_state.past.append(user_input)
+    output = answer(st.session_state['past'], st.session_state["generated"])
     st.session_state.generated.append(output)
-    bot_history.append(output)
+    #bot_history.append(output)
 
 if st.session_state['generated']:
 
@@ -86,4 +99,16 @@ if st.session_state['generated']:
     for i in range(len(st.session_state['generated'])):
         message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
         message(st.session_state["generated"][i], key=str(i))
-
+
+
+if st.button("清理模型缓存"):  # "Clear model cache"
+    # Drop every @st.cache_resource entry so the model and tokenizer
+    # are reloaded on the next run, then release cached GPU memory
+    # (torch.cuda.empty_cache() is a no-op while the app runs on CPU)
+    st.cache_resource.clear()
+    torch.cuda.empty_cache()
+
+if st.button("清理对话缓存"):  # "Clear chat history"
+    # Reset the conversation history kept in session state
+    st.session_state['generated'] = []
+    st.session_state['past'] = []
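
For context, the new code path calls answer(st.session_state['past'], st.session_state["generated"]) but the hunks never show that helper's body. Below is a minimal standalone sketch of what such a helper can look like, assuming the ChatYuan-large-v1 prompt convention from the model card (turns joined as 用户:…\n小元:…, with newlines escaped before encoding). The history-joining logic, the postprocess pairing, and the generation parameters are illustrative assumptions, not the repo's actual code.

# Sketch only, NOT the space's actual code: the diff calls answer()
# without showing its body. Prompt format follows the ChatYuan-large-v1
# model card; generation parameters here are illustrative guesses.
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = torch.device('cpu')
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1").to(device)

def preprocess(text):
    # ChatYuan expects literal newlines/tabs escaped in the prompt
    return text.replace("\n", "\\n").replace("\t", "\\t")

def postprocess(text):
    return text.replace("\\n", "\n").replace("\\t", "\t")

def answer(user_history, bot_history, top_p=0.9, temperature=0.7):
    # 'past' is appended before answer() runs, so user_history is one
    # turn longer than bot_history; zip() pairs the completed turns and
    # the prompt ends with an open "小元:" turn for the model to finish.
    context = ""
    for user_turn, bot_turn in zip(user_history, bot_history):
        context += f"用户:{user_turn}\n小元:{bot_turn}\n"
    context += f"用户:{user_history[-1]}\n小元:"
    encoding = tokenizer(preprocess(context), truncation=True,
                         max_length=768, return_tensors="pt").to(device)
    out = model.generate(**encoding, max_new_tokens=512, do_sample=True,
                         top_p=top_p, temperature=temperature,
                         no_repeat_ngram_size=3)
    return postprocess(tokenizer.decode(out[0], skip_special_tokens=True))

With the loaders wrapped in @st.cache_resource as in the diff, a helper like this reuses the cached model across Streamlit's script reruns instead of reloading it on every interaction.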