Tuchuanhuhuhu commited on
Commit
93def2f
1 Parent(s): f8a0305

川虎助理和川虎助理Pro支持流式输出

Browse files
config_example.json CHANGED
@@ -17,7 +17,7 @@
17
  "default_model": "gpt-3.5-turbo", // 默认模型
18
 
19
  //川虎助理设置
20
- "default_chuanhu_assistant_model": "gpt-4", //川虎助理使用的模型,可选gpt-3.5或者gpt-4
21
  "GOOGLE_CSE_ID": "", //谷歌搜索引擎ID,用于川虎助理Pro模式,获取方式请看 https://stackoverflow.com/questions/37083058/programmatically-searching-google-in-python-using-custom-search
22
  "GOOGLE_API_KEY": "", //谷歌API Key,用于川虎助理Pro模式
23
  "WOLFRAM_ALPHA_APPID": "", //Wolfram Alpha API Key,用于川虎助理Pro模式,获取方式请看 https://products.wolframalpha.com/api/
 
17
  "default_model": "gpt-3.5-turbo", // 默认模型
18
 
19
  //川虎助理设置
20
+ "default_chuanhu_assistant_model": "gpt-4", //川虎助理使用的模型,可选gpt-3.5-turbo或者gpt-4
21
  "GOOGLE_CSE_ID": "", //谷歌搜索引擎ID,用于川虎助理Pro模式,获取方式请看 https://stackoverflow.com/questions/37083058/programmatically-searching-google-in-python-using-custom-search
22
  "GOOGLE_API_KEY": "", //谷歌API Key,用于川虎助理Pro模式
23
  "WOLFRAM_ALPHA_APPID": "", //Wolfram Alpha API Key,用于川虎助理Pro模式,获取方式请看 https://products.wolframalpha.com/api/
modules/models/ChuanhuAgent.py CHANGED
@@ -1,8 +1,6 @@
1
  from langchain.chains.summarize import load_summarize_chain
2
- from langchain import OpenAI, PromptTemplate, LLMChain
3
  from langchain.chat_models import ChatOpenAI
4
- from langchain.text_splitter import CharacterTextSplitter
5
- from langchain.chains.mapreduce import MapReduceChain
6
  from langchain.prompts import PromptTemplate
7
  from langchain.text_splitter import TokenTextSplitter
8
  from langchain.embeddings import OpenAIEmbeddings
@@ -14,14 +12,23 @@ from langchain.agents import AgentType
14
  from langchain.docstore.document import Document
15
  from langchain.tools import BaseTool, StructuredTool, Tool, tool
16
  from langchain.callbacks.stdout import StdOutCallbackHandler
 
17
  from langchain.callbacks.manager import BaseCallbackManager
18
 
 
 
 
 
 
 
19
  from pydantic import BaseModel, Field
20
 
21
  import requests
22
  from bs4 import BeautifulSoup
 
 
23
 
24
- from .base_model import BaseLLMModel
25
  from ..config import default_chuanhu_assistant_model
26
  from ..presets import SUMMARIZE_PROMPT
27
  import logging
@@ -40,8 +47,9 @@ class ChuanhuAgent_Client(BaseLLMModel):
40
  self.text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=30)
41
  self.api_key = openai_api_key
42
  self.llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0, model_name=default_chuanhu_assistant_model)
 
43
  PROMPT = PromptTemplate(template=SUMMARIZE_PROMPT, input_variables=["text"])
44
- self.summarize_chain = load_summarize_chain(self.llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
45
  if "Pro" in self.model_name:
46
  self.tools = load_tools(["google-search-results-json", "llm-math", "arxiv", "wikipedia", "wolfram-alpha"], llm=self.llm)
47
  else:
@@ -96,13 +104,28 @@ class ChuanhuAgent_Client(BaseLLMModel):
96
  # create vectorstore
97
  db = FAISS.from_documents(texts, embeddings)
98
  retriever = db.as_retriever()
99
- qa = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff", retriever=retriever)
100
  return qa.run(f"{question} Reply in 中文")
101
 
102
  def get_answer_at_once(self):
103
  question = self.history[-1]["content"]
104
- manager = BaseCallbackManager(handlers=[StdOutCallbackHandler()])
105
  # llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
106
- agent = initialize_agent(self.tools, self.llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)
107
  reply = agent.run(input=f"{question} Reply in 简体中文")
108
  return reply, -1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langchain.chains.summarize import load_summarize_chain
2
+ from langchain import PromptTemplate, LLMChain
3
  from langchain.chat_models import ChatOpenAI
 
 
4
  from langchain.prompts import PromptTemplate
5
  from langchain.text_splitter import TokenTextSplitter
6
  from langchain.embeddings import OpenAIEmbeddings
 
12
  from langchain.docstore.document import Document
13
  from langchain.tools import BaseTool, StructuredTool, Tool, tool
14
  from langchain.callbacks.stdout import StdOutCallbackHandler
15
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
16
  from langchain.callbacks.manager import BaseCallbackManager
17
 
18
+ from typing import Any, Dict, List, Optional, Union
19
+
20
+ from langchain.callbacks.base import BaseCallbackHandler
21
+ from langchain.input import print_text
22
+ from langchain.schema import AgentAction, AgentFinish, LLMResult
23
+
24
  from pydantic import BaseModel, Field
25
 
26
  import requests
27
  from bs4 import BeautifulSoup
28
+ from threading import Thread, Condition
29
+ from collections import deque
30
 
31
+ from .base_model import BaseLLMModel, CallbackToIterator, ChuanhuCallbackHandler
32
  from ..config import default_chuanhu_assistant_model
33
  from ..presets import SUMMARIZE_PROMPT
34
  import logging
 
47
  self.text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=30)
48
  self.api_key = openai_api_key
49
  self.llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0, model_name=default_chuanhu_assistant_model)
50
+ self.cheap_llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0, model_name="gpt-3.5-turbo")
51
  PROMPT = PromptTemplate(template=SUMMARIZE_PROMPT, input_variables=["text"])
52
+ self.summarize_chain = load_summarize_chain(self.cheap_llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
53
  if "Pro" in self.model_name:
54
  self.tools = load_tools(["google-search-results-json", "llm-math", "arxiv", "wikipedia", "wolfram-alpha"], llm=self.llm)
55
  else:
 
104
  # create vectorstore
105
  db = FAISS.from_documents(texts, embeddings)
106
  retriever = db.as_retriever()
107
+ qa = RetrievalQA.from_chain_type(llm=self.cheap_llm, chain_type="stuff", retriever=retriever)
108
  return qa.run(f"{question} Reply in 中文")
109
 
110
  def get_answer_at_once(self):
111
  question = self.history[-1]["content"]
 
112
  # llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
113
+ agent = initialize_agent(self.tools, self.llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
114
  reply = agent.run(input=f"{question} Reply in 简体中文")
115
  return reply, -1
116
+
117
+ def get_answer_stream_iter(self):
118
+ question = self.history[-1]["content"]
119
+ it = CallbackToIterator()
120
+ manager = BaseCallbackManager(handlers=[ChuanhuCallbackHandler(it.callback)])
121
+ def thread_func():
122
+ agent = initialize_agent(self.tools, self.llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)
123
+ reply = agent.run(input=f"{question} Reply in 简体中文")
124
+ it.callback(reply)
125
+ it.finish()
126
+ t = Thread(target=thread_func)
127
+ t.start()
128
+ partial_text = ""
129
+ for value in it:
130
+ partial_text += value
131
+ yield partial_text
modules/models/base_model.py CHANGED
@@ -18,12 +18,85 @@ import asyncio
18
  import aiohttp
19
  from enum import Enum
20
 
 
 
 
 
 
 
 
 
 
 
 
21
  from ..presets import *
22
  from ..index_func import *
23
  from ..utils import *
24
  from .. import shared
25
  from ..config import retrieve_proxy
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  class ModelType(Enum):
29
  Unknown = -1
 
18
  import aiohttp
19
  from enum import Enum
20
 
21
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
22
+ from langchain.callbacks.manager import BaseCallbackManager
23
+
24
+ from typing import Any, Dict, List, Optional, Union
25
+
26
+ from langchain.callbacks.base import BaseCallbackHandler
27
+ from langchain.input import print_text
28
+ from langchain.schema import AgentAction, AgentFinish, LLMResult
29
+ from threading import Thread, Condition
30
+ from collections import deque
31
+
32
  from ..presets import *
33
  from ..index_func import *
34
  from ..utils import *
35
  from .. import shared
36
  from ..config import retrieve_proxy
37
 
38
+ class CallbackToIterator:
39
+ def __init__(self):
40
+ self.queue = deque()
41
+ self.cond = Condition()
42
+ self.finished = False
43
+
44
+ def callback(self, result):
45
+ with self.cond:
46
+ self.queue.append(result)
47
+ self.cond.notify() # Wake up the generator.
48
+
49
+ def __iter__(self):
50
+ return self
51
+
52
+ def __next__(self):
53
+ with self.cond:
54
+ while not self.queue and not self.finished: # Wait for a value to be added to the queue.
55
+ self.cond.wait()
56
+ if not self.queue:
57
+ raise StopIteration()
58
+ return self.queue.popleft()
59
+
60
+ def finish(self):
61
+ with self.cond:
62
+ self.finished = True
63
+ self.cond.notify() # Wake up the generator if it's waiting.
64
+
65
+ class ChuanhuCallbackHandler(BaseCallbackHandler):
66
+
67
+ def __init__(self, callback) -> None:
68
+ """Initialize callback handler."""
69
+ self.callback = callback
70
+
71
+ def on_agent_action(
72
+ self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
73
+ ) -> Any:
74
+ self.callback(action.log)
75
+
76
+ def on_tool_end(
77
+ self,
78
+ output: str,
79
+ color: Optional[str] = None,
80
+ observation_prefix: Optional[str] = None,
81
+ llm_prefix: Optional[str] = None,
82
+ **kwargs: Any,
83
+ ) -> None:
84
+ """If not the final action, print out observation."""
85
+ if observation_prefix is not None:
86
+ self.callback(f"\n\n{observation_prefix}")
87
+ self.callback(output)
88
+ if llm_prefix is not None:
89
+ self.callback(f"\n\n{llm_prefix}")
90
+
91
+ def on_agent_finish(
92
+ self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
93
+ ) -> None:
94
+ self.callback(f"{finish.log}\n\n")
95
+
96
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
97
+ """Run on new LLM token. Only available when streaming is enabled."""
98
+ self.callback(token)
99
+
100
 
101
  class ModelType(Enum):
102
  Unknown = -1