Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -27,19 +27,17 @@ def preprocess(text):
|
|
27 |
def postprocess(text):
|
28 |
return text.replace("\\n", "\n").replace("\\t", "\t")
|
29 |
|
30 |
-
def answer(
|
31 |
'''sample:是否抽样。生成任务,可以设置为True;
|
32 |
top_p:0-1之间,生成的内容越多样
|
33 |
max_new_tokens=512 lost...'''
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
for i in range(len(history)):
|
38 |
-
preprocess_history.append(preprocess(history[i]))
|
39 |
|
40 |
-
|
41 |
-
print(
|
42 |
-
encoding = tokenizer(text=
|
43 |
if not sample:
|
44 |
out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
|
45 |
else:
|
@@ -67,15 +65,16 @@ def get_text():
|
|
67 |
input_text = st.text_input("用户: ","你好!", key="input")
|
68 |
return input_text
|
69 |
|
70 |
-
|
|
|
71 |
user_input = get_text()
|
72 |
-
|
73 |
|
74 |
if user_input:
|
75 |
-
output = answer(
|
76 |
st.session_state.past.append(user_input)
|
77 |
st.session_state.generated.append(output)
|
78 |
-
|
79 |
|
80 |
if st.session_state['generated']:
|
81 |
|
|
|
27 |
def postprocess(text):
|
28 |
return text.replace("\\n", "\n").replace("\\t", "\t")
|
29 |
|
30 |
+
def answer(user_history, bot_history, sample=True, top_p=1, temperature=0.7):
|
31 |
'''sample:是否抽样。生成任务,可以设置为True;
|
32 |
top_p:0-1之间,生成的内容越多样
|
33 |
max_new_tokens=512 lost...'''
|
34 |
|
35 |
+
context = "\n".join([f"用户:{user_history[i]}\n小元:{bot_history[i]}" for i in range(len(bot_history))])
|
36 |
+
input_text = context + "\n用户:" + user_history[-1] + "\n小元:"
|
|
|
|
|
37 |
|
38 |
+
input_text = preprocess(input_text)
|
39 |
+
print(input_text)
|
40 |
+
encoding = tokenizer(text=input_text, truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
|
41 |
if not sample:
|
42 |
out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
|
43 |
else:
|
|
|
65 |
input_text = st.text_input("用户: ","你好!", key="input")
|
66 |
return input_text
|
67 |
|
68 |
+
user_history = []
|
69 |
+
bot_history = []
|
70 |
user_input = get_text()
|
71 |
+
user_history.append(user_input)
|
72 |
|
73 |
if user_input:
|
74 |
+
output = answer(user_history,bot_history)
|
75 |
st.session_state.past.append(user_input)
|
76 |
st.session_state.generated.append(output)
|
77 |
+
bot_history.append(output)
|
78 |
|
79 |
if st.session_state['generated']:
|
80 |
|