scutcyr commited on
Commit
e917743
1 Parent(s): 22b5bfb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -27,19 +27,17 @@ def preprocess(text):
27
  def postprocess(text):
28
  return text.replace("\\n", "\n").replace("\\t", "\t")
29
 
30
- def answer(history, sample=True, top_p=1, temperature=0.7):
31
  '''sample:是否抽样。生成任务,可以设置为True;
32
  top_p:0-1之间,生成的内容越多样
33
  max_new_tokens=512 lost...'''
34
 
35
- preprocess_history = []
36
-
37
- for i in range(len(history)):
38
- preprocess_history.append(preprocess(history[i]))
39
 
40
- #text = preprocess(text)
41
- print('用户: '+preprocess_history[-1])
42
- encoding = tokenizer(text=preprocess_history, truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
43
  if not sample:
44
  out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
45
  else:
@@ -67,15 +65,16 @@ def get_text():
67
  input_text = st.text_input("用户: ","你好!", key="input")
68
  return input_text
69
 
70
- history = []
 
71
  user_input = get_text()
72
- history.append(user_input)
73
 
74
  if user_input:
75
- output = answer(history)
76
  st.session_state.past.append(user_input)
77
  st.session_state.generated.append(output)
78
- history.append(output)
79
 
80
  if st.session_state['generated']:
81
 
 
27
  def postprocess(text):
28
  return text.replace("\\n", "\n").replace("\\t", "\t")
29
 
30
+ def answer(user_history, bot_history, sample=True, top_p=1, temperature=0.7):
31
  '''sample:是否抽样。生成任务,可以设置为True;
32
  top_p:0-1之间,生成的内容越多样
33
  max_new_tokens=512 lost...'''
34
 
35
+ context = "\n".join([f"用户:{user_history[i]}\n小元:{bot_history[i]}" for i in range(len(bot_history))])
36
+ input_text = context + "\n用户:" + user_history[-1] + "\n小元:"
 
 
37
 
38
+ input_text = preprocess(input_text)
39
+ print(input_text)
40
+ encoding = tokenizer(text=input_text, truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
41
  if not sample:
42
  out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
43
  else:
 
65
  input_text = st.text_input("用户: ","你好!", key="input")
66
  return input_text
67
 
68
+ user_history = []
69
+ bot_history = []
70
  user_input = get_text()
71
+ user_history.append(user_input)
72
 
73
  if user_input:
74
+ output = answer(user_history,bot_history)
75
  st.session_state.past.append(user_input)
76
  st.session_state.generated.append(output)
77
+ bot_history.append(output)
78
 
79
  if st.session_state['generated']:
80